diff --git a/.cspell.dict/cpython.txt b/.cspell.dict/cpython.txt index b8081e25a9b..a9fbc8f4318 100644 --- a/.cspell.dict/cpython.txt +++ b/.cspell.dict/cpython.txt @@ -127,8 +127,8 @@ NEWLOCALS newsemlockobject nfrees nkwargs -nlocalsplus nkwelts +nlocalsplus Nondescriptor noninteger nops @@ -154,12 +154,14 @@ prec preinitialized pybuilddir pycore +pyinner pydecimal Pyfunc pylifecycle pymain pyrepl PYTHONTRACEMALLOC +PYTHONUTF8 pythonw PYTHREAD_NAME releasebuffer @@ -171,9 +173,11 @@ saveall scls setdict setfunc +setprofileallthreads SETREF setresult setslice +settraceallthreads SLOTDEFINED SMALLBUF SOABI @@ -190,8 +194,10 @@ subparams subscr sval swappedbytes +sysdict templatelib testconsole +threadstate ticketer tmptype tok_oldval diff --git a/.cspell.dict/rust-more.txt b/.cspell.dict/rust-more.txt index c3ebd61833a..c4457723c6c 100644 --- a/.cspell.dict/rust-more.txt +++ b/.cspell.dict/rust-more.txt @@ -5,7 +5,9 @@ biguint bindgen bitand bitflags +bitflagset bitor +bitvec bitxor bstr byteorder @@ -58,6 +60,7 @@ powi prepended punct replacen +retag rmatch rposition rsplitn @@ -89,5 +92,3 @@ widestring winapi winresource winsock -bitvec -Bitvec diff --git a/.cspell.json b/.cspell.json index e2b1d86aaeb..07fe948c5bf 100644 --- a/.cspell.json +++ b/.cspell.json @@ -152,11 +152,6 @@ "IFEXEC", // "stat" "FIRMLINK", - // CPython internal names - "PYTHONUTF", - "sysdict", - "settraceallthreads", - "setprofileallthreads" ], // flagWords - list of words to be always considered incorrect "flagWords": [ diff --git a/.github/actions/install-linux-deps/action.yml b/.github/actions/install-linux-deps/action.yml new file mode 100644 index 00000000000..7900060fb29 --- /dev/null +++ b/.github/actions/install-linux-deps/action.yml @@ -0,0 +1,49 @@ +# This action installs a few dependencies necessary to build RustPython on Linux. 
+# It can be configured depending on which libraries are needed: +# +# ``` +# - uses: ./.github/actions/install-linux-deps +# with: +# gcc-multilib: true +# musl-tools: false +# ``` +# +# See the `inputs` section for all options and their defaults. Note that you must checkout the +# repository before you can use this action. +# +# This action will only install dependencies when the current operating system is Linux. It will do +# nothing on any other OS (macOS, Windows). + +name: Install Linux dependencies +description: Installs the dependencies necessary to build RustPython on Linux. +inputs: + gcc-multilib: + description: Install gcc-multilib (gcc-multilib) + required: false + default: "false" + musl-tools: + description: Install musl-tools (musl-tools) + required: false + default: "false" + gcc-aarch64-linux-gnu: + description: Install gcc-aarch64-linux-gnu (gcc-aarch64-linux-gnu) + required: false + default: "false" + clang: + description: Install clang (clang) + required: false + default: "false" +runs: + using: composite + steps: + - name: Install Linux dependencies + shell: bash + if: ${{ runner.os == 'Linux' }} + run: > + sudo apt-get update + + sudo apt-get install --no-install-recommends + ${{ fromJSON(inputs.gcc-multilib) && 'gcc-multilib' || '' }} + ${{ fromJSON(inputs.musl-tools) && 'musl-tools' || '' }} + ${{ fromJSON(inputs.clang) && 'clang' || '' }} + ${{ fromJSON(inputs.gcc-aarch64-linux-gnu) && 'gcc-aarch64-linux-gnu linux-libc-dev-arm64-cross libc6-dev-arm64-cross' || '' }} diff --git a/.github/actions/install-macos-deps/action.yml b/.github/actions/install-macos-deps/action.yml new file mode 100644 index 00000000000..46abef197a4 --- /dev/null +++ b/.github/actions/install-macos-deps/action.yml @@ -0,0 +1,47 @@ +# This action installs a few dependencies necessary to build RustPython on macOS. 
By default it installs +# autoconf, automake and libtool, but can be configured depending on which libraries are needed: +# +# ``` +# - uses: ./.github/actions/install-macos-deps +# with: +# openssl: true +# libtool: false +# ``` +# +# See the `inputs` section for all options and their defaults. Note that you must checkout the +# repository before you can use this action. +# +# This action will only install dependencies when the current operating system is macOS. It will do +# nothing on any other OS (Linux, Windows). + +name: Install macOS dependencies +description: Installs the dependencies necessary to build RustPython on macOS. +inputs: + autoconf: + description: Install autoconf (autoconf) + required: false + default: "true" + automake: + description: Install automake (automake) + required: false + default: "true" + libtool: + description: Install libtool (libtool) + required: false + default: "true" + openssl: + description: Install openssl (openssl@3) + required: false + default: "false" +runs: + using: composite + steps: + - name: Install macOS dependencies + shell: bash + if: ${{ runner.os == 'macOS' }} + run: > + brew install + ${{ fromJSON(inputs.autoconf) && 'autoconf' || '' }} + ${{ fromJSON(inputs.automake) && 'automake' || '' }} + ${{ fromJSON(inputs.libtool) && 'libtool' || '' }} + ${{ fromJSON(inputs.openssl) && 'openssl@3' || '' }} diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index a58490666c7..15b4997cfcf 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -129,25 +129,20 @@ jobs: os: [macos-latest, ubuntu-latest, windows-2025] fail-fast: false steps: - - uses: actions/checkout@v6.0.2 + - uses: actions/checkout@v6 - uses: dtolnay/rust-toolchain@stable with: components: clippy - uses: Swatinem/rust-cache@v2 - - name: Set up the Mac environment - run: brew install autoconf automake libtool - if: runner.os == 'macOS' + - name: Install macOS dependencies + uses: ./.github/actions/install-macos-deps - name: 
run clippy run: cargo clippy ${{ env.CARGO_ARGS }} --workspace --all-targets ${{ env.WORKSPACE_EXCLUDES }} -- -Dwarnings - name: run rust tests run: cargo test --workspace ${{ env.WORKSPACE_EXCLUDES }} --verbose --features threading ${{ env.CARGO_ARGS }} - if: runner.os != 'macOS' - - name: run rust tests - run: cargo test --workspace ${{ env.WORKSPACE_EXCLUDES }} --exclude rustpython-jit --verbose --features threading ${{ env.CARGO_ARGS }} - if: runner.os == 'macOS' - name: check compilation without threading run: cargo check ${{ env.CARGO_ARGS }} @@ -189,94 +184,58 @@ jobs: PYTHONPATH: scripts if: runner.os == 'Linux' - - name: prepare Intel MacOS build - uses: dtolnay/rust-toolchain@stable - with: - target: x86_64-apple-darwin - if: runner.os == 'macOS' - - name: Check compilation for Intel MacOS - run: cargo check --target x86_64-apple-darwin - if: runner.os == 'macOS' - - name: prepare iOS build - uses: dtolnay/rust-toolchain@stable - with: - target: aarch64-apple-ios - if: runner.os == 'macOS' - - name: Check compilation for iOS - run: cargo check --target aarch64-apple-ios ${{ env.CARGO_ARGS_NO_SSL }} - if: runner.os == 'macOS' - - exotic_targets: + cargo_check: if: ${{ !contains(github.event.pull_request.labels.*.name, 'skip:ci') }} name: Ensure compilation on various targets - runs-on: ubuntu-latest - timeout-minutes: 30 + runs-on: ${{ matrix.os }} + strategy: + matrix: + include: + - os: ubuntu-latest + targets: + - aarch64-linux-android + - i686-unknown-linux-gnu + - i686-unknown-linux-musl + - wasm32-wasip2 + - x86_64-unknown-freebsd + dependencies: + gcc-multilib: true + musl-tools: true + - os: ubuntu-latest + targets: + - aarch64-unknown-linux-gnu + dependencies: + gcc-aarch64-linux-gnu: true # conflict with `gcc-multilib` + - os: macos-latest + targets: + - aarch64-apple-ios + - x86_64-apple-darwin + fail-fast: false steps: - - uses: actions/checkout@v6.0.2 - - uses: dtolnay/rust-toolchain@stable + - uses: 
actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: - target: i686-unknown-linux-gnu + persist-credentials: false - - name: Install gcc-multilib and musl-tools - run: sudo apt-get update && sudo apt-get install gcc-multilib musl-tools - - name: Check compilation for x86 32bit - run: cargo check --target i686-unknown-linux-gnu ${{ env.CARGO_ARGS_NO_SSL }} + - uses: Swatinem/rust-cache@v2 + with: + prefix-key: v0-rust-${{ join(matrix.targets, '-') }} + + - name: Install dependencies + uses: ./.github/actions/install-linux-deps + with: ${{ matrix.dependencies || fromJSON('{}') }} - uses: dtolnay/rust-toolchain@stable with: - target: aarch64-linux-android + targets: ${{ join(matrix.targets, ',') }} - name: Setup Android NDK + if: ${{ contains(matrix.targets, 'aarch64-linux-android') }} id: setup-ndk uses: nttld/setup-ndk@v1 with: ndk-version: r27 add-to-path: true - - name: Check compilation for android - run: cargo check --target aarch64-linux-android ${{ env.CARGO_ARGS_NO_SSL }} - env: - CC_aarch64_linux_android: ${{ steps.setup-ndk.outputs.ndk-path }}/toolchains/llvm/prebuilt/linux-x86_64/bin/aarch64-linux-android24-clang - AR_aarch64_linux_android: ${{ steps.setup-ndk.outputs.ndk-path }}/toolchains/llvm/prebuilt/linux-x86_64/bin/llvm-ar - CARGO_TARGET_AARCH64_LINUX_ANDROID_LINKER: ${{ steps.setup-ndk.outputs.ndk-path }}/toolchains/llvm/prebuilt/linux-x86_64/bin/aarch64-linux-android24-clang - - - uses: dtolnay/rust-toolchain@stable - with: - target: aarch64-unknown-linux-gnu - - - name: Install gcc-aarch64-linux-gnu - run: sudo apt install gcc-aarch64-linux-gnu - - name: Check compilation for aarch64 linux gnu - run: cargo check --target aarch64-unknown-linux-gnu ${{ env.CARGO_ARGS_NO_SSL }} - - - uses: dtolnay/rust-toolchain@stable - with: - target: i686-unknown-linux-musl - - - name: Check compilation for musl - run: cargo check --target i686-unknown-linux-musl ${{ env.CARGO_ARGS_NO_SSL }} - - - uses: dtolnay/rust-toolchain@stable - with: - 
target: x86_64-unknown-freebsd - - - name: Check compilation for freebsd - run: cargo check --target x86_64-unknown-freebsd ${{ env.CARGO_ARGS_NO_SSL }} - - - uses: dtolnay/rust-toolchain@stable - with: - target: x86_64-unknown-freebsd - - - name: Check compilation for freeBSD - run: cargo check --target x86_64-unknown-freebsd ${{ env.CARGO_ARGS_NO_SSL }} - - - uses: dtolnay/rust-toolchain@stable - with: - target: wasm32-wasip2 - - - name: Check compilation for wasip2 - run: cargo check --target wasm32-wasip2 ${{ env.CARGO_ARGS_NO_SSL }} - # - name: Prepare repository for redox compilation # run: bash scripts/redox/uncomment-cargo.sh # - name: Check compilation for Redox @@ -285,6 +244,19 @@ jobs: # command: check # args: --ignore-rust-version + - name: Check compilation + run: | + for target in ${{ join(matrix.targets, ' ') }} + do + echo "::group::${target}" + cargo check --target $target ${{ env.CARGO_ARGS_NO_SSL }} + echo "::endgroup::" + done + env: + CC_aarch64_linux_android: ${{ steps.setup-ndk.outputs.ndk-path }}/toolchains/llvm/prebuilt/linux-x86_64/bin/aarch64-linux-android24-clang + AR_aarch64_linux_android: ${{ steps.setup-ndk.outputs.ndk-path }}/toolchains/llvm/prebuilt/linux-x86_64/bin/llvm-ar + CARGO_TARGET_AARCH64_LINUX_ANDROID_LINKER: ${{ steps.setup-ndk.outputs.ndk-path }}/toolchains/llvm/prebuilt/linux-x86_64/bin/aarch64-linux-android24-clang + snippets_cpython: if: ${{ !contains(github.event.pull_request.labels.*.name, 'skip:ci') }} env: @@ -293,27 +265,27 @@ jobs: runs-on: ${{ matrix.os }} strategy: matrix: - os: [macos-latest, ubuntu-latest, windows-2025] + os: + - macos-latest + - ubuntu-latest + - windows-2025 fail-fast: false steps: - - uses: actions/checkout@v6.0.2 + - uses: actions/checkout@v6 - uses: dtolnay/rust-toolchain@stable - uses: Swatinem/rust-cache@v2 - uses: actions/setup-python@v6.2.0 with: python-version: ${{ env.PYTHON_VERSION }} - - name: Set up the Mac environment - run: brew install autoconf automake libtool openssl@3 - 
if: runner.os == 'macOS' - - name: build rustpython - run: cargo build --release --verbose --features=threading ${{ env.CARGO_ARGS }} - if: runner.os == 'macOS' - - name: build rustpython - run: cargo build --release --verbose --features=threading ${{ env.CARGO_ARGS }},jit - if: runner.os != 'macOS' - - uses: actions/setup-python@v6.2.0 + + - name: Install macOS dependencies + uses: ./.github/actions/install-macos-deps with: - python-version: ${{ env.PYTHON_VERSION }} + openssl: true + + - name: build rustpython + run: cargo build --release --verbose --features=threading,jit ${{ env.CARGO_ARGS }} + - name: run snippets run: python -m pip install -r requirements.txt && pytest -v working-directory: ./extra_tests @@ -445,20 +417,16 @@ jobs: run: | target/release/rustpython -m venv testvenv testvenv/bin/rustpython -m pip install wheel - - if: runner.os != 'macOS' - name: Check whats_left is not broken - shell: bash - run: python -I scripts/whats_left.py --no-default-features --features "$(sed -e 's/--[^ ]*//g' <<< "${{ env.CARGO_ARGS }}" | tr -d '[:space:]'),threading,jit" - - if: runner.os == 'macOS' # TODO fix jit on macOS - name: Check whats_left is not broken (macOS) + + - name: Check whats_left is not broken shell: bash - run: python -I scripts/whats_left.py --no-default-features --features "$(sed -e 's/--[^ ]*//g' <<< "${{ env.CARGO_ARGS }}" | tr -d '[:space:]'),threading" # no jit on macOS for now + run: python -I scripts/whats_left.py ${{ env.CARGO_ARGS }} --features jit lint: name: Lint Rust & Python code runs-on: ubuntu-latest steps: - - uses: actions/checkout@v6.0.2 + - uses: actions/checkout@v6 - uses: actions/setup-python@v6.2.0 with: python-version: ${{ env.PYTHON_VERSION }} @@ -486,7 +454,7 @@ jobs: - name: Install ruff uses: astral-sh/ruff-action@4919ec5cf1f49eff0871dbcea0da843445b837e6 # v3.6.1 with: - version: "0.15.4" + version: "0.15.5" args: "--version" - run: ruff check --diff @@ -514,9 +482,9 @@ jobs: runs-on: ubuntu-latest timeout-minutes: 30 
env: - NIGHTLY_CHANNEL: nightly-2026-02-11 # https://github.com/rust-lang/miri/issues/4855 + NIGHTLY_CHANNEL: nightly steps: - - uses: actions/checkout@v6.0.2 + - uses: actions/checkout@v6 - uses: dtolnay/rust-toolchain@master with: @@ -538,7 +506,7 @@ jobs: runs-on: ubuntu-latest timeout-minutes: 30 steps: - - uses: actions/checkout@v6.0.2 + - uses: actions/checkout@v6 - uses: dtolnay/rust-toolchain@stable - uses: Swatinem/rust-cache@v2 @@ -601,7 +569,7 @@ jobs: runs-on: ubuntu-latest timeout-minutes: 30 steps: - - uses: actions/checkout@v6.0.2 + - uses: actions/checkout@v6 - uses: dtolnay/rust-toolchain@stable with: target: wasm32-wasip1 @@ -609,8 +577,12 @@ jobs: - uses: Swatinem/rust-cache@v2 - name: Setup Wasmer uses: wasmerio/setup-wasmer@v3 + - name: Install clang - run: sudo apt-get update && sudo apt-get install clang -y + uses: ./.github/actions/install-linux-deps + with: + clang: true + - name: build rustpython run: cargo build --release --target wasm32-wasip1 --features freeze-stdlib,stdlib --verbose - name: run snippets diff --git a/.github/workflows/cron-ci.yaml b/.github/workflows/cron-ci.yaml index f451984fb53..64a7d5c88e5 100644 --- a/.github/workflows/cron-ci.yaml +++ b/.github/workflows/cron-ci.yaml @@ -24,7 +24,7 @@ jobs: # Disable this scheduled job when running on a fork. if: ${{ github.repository == 'RustPython/RustPython' || github.event_name != 'schedule' }} steps: - - uses: actions/checkout@v6.0.2 + - uses: actions/checkout@v6 - uses: dtolnay/rust-toolchain@stable - uses: taiki-e/install-action@cargo-llvm-cov - uses: actions/setup-python@v6.2.0 @@ -53,7 +53,7 @@ jobs: # Disable this scheduled job when running on a fork. 
if: ${{ github.repository == 'RustPython/RustPython' || github.event_name != 'schedule' }} steps: - - uses: actions/checkout@v6.0.2 + - uses: actions/checkout@v6 - uses: dtolnay/rust-toolchain@stable - name: build rustpython run: cargo build --release --verbose @@ -85,7 +85,7 @@ jobs: # Disable this scheduled job when running on a fork. if: ${{ github.repository == 'RustPython/RustPython' || github.event_name != 'schedule' }} steps: - - uses: actions/checkout@v6.0.2 + - uses: actions/checkout@v6 - uses: dtolnay/rust-toolchain@stable - uses: actions/setup-python@v6.2.0 with: @@ -143,7 +143,7 @@ jobs: # Disable this scheduled job when running on a fork. if: ${{ github.repository == 'RustPython/RustPython' || github.event_name != 'schedule' }} steps: - - uses: actions/checkout@v6.0.2 + - uses: actions/checkout@v6 - uses: dtolnay/rust-toolchain@stable - uses: actions/setup-python@v6.2.0 with: diff --git a/.github/workflows/lib-deps-check.yaml b/.github/workflows/lib-deps-check.yaml index 550ba2f2529..4eed6b77b16 100644 --- a/.github/workflows/lib-deps-check.yaml +++ b/.github/workflows/lib-deps-check.yaml @@ -21,7 +21,7 @@ jobs: timeout-minutes: 10 steps: - name: Checkout base branch - uses: actions/checkout@v6.0.2 + uses: actions/checkout@v6 with: # Use base branch for scripts (security: don't run PR code with elevated permissions) ref: ${{ github.event.pull_request.base.ref }} diff --git a/.github/workflows/pr-auto-commit.yaml b/.github/workflows/pr-auto-commit.yaml deleted file mode 100644 index ceaa78ba28b..00000000000 --- a/.github/workflows/pr-auto-commit.yaml +++ /dev/null @@ -1,122 +0,0 @@ -name: Auto-format PR - -# This workflow triggers when a PR is opened/updated -on: - pull_request_target: - types: [opened, synchronize, reopened] - branches: - - main - - release - -concurrency: - group: auto-format-${{ github.event.pull_request.number }} - cancel-in-progress: true - -jobs: - auto_format: - permissions: - contents: write - pull-requests: write - runs-on: 
ubuntu-latest - timeout-minutes: 60 - steps: - - name: Checkout PR branch - uses: actions/checkout@v6.0.2 - with: - ref: ${{ github.event.pull_request.head.sha }} - repository: ${{ github.event.pull_request.head.repo.full_name }} - token: ${{ secrets.AUTO_COMMIT_PAT }} - fetch-depth: 0 - - - name: Setup Rust - uses: dtolnay/rust-toolchain@stable - with: - components: rustfmt - - - name: Configure git - run: | - git config user.name "github-actions[bot]" - git config user.email "github-actions[bot]@users.noreply.github.com" - echo "" > /tmp/committed_commands.txt - - - name: Run cargo fmt - run: | - echo "Running cargo fmt --all on PR #${{ github.event.pull_request.number }}" - cargo fmt --all - if [ -n "$(git status --porcelain)" ]; then - git add -u - git commit -m "Auto-format: cargo fmt --all" - echo "- \`cargo fmt --all\`" >> /tmp/committed_commands.txt - fi - - - name: Install ruff - uses: astral-sh/ruff-action@4919ec5cf1f49eff0871dbcea0da843445b837e6 # v3.6.1 - with: - version: "0.15.4" - args: "--version" - - - name: Run ruff format - run: | - ruff format - if [ -n "$(git status --porcelain)" ]; then - git add -u - git commit -m "Auto-format: ruff format" - echo "- \`ruff format\`" >> /tmp/committed_commands.txt - fi - - - name: Run ruff check import sorting - run: | - ruff check --select I --fix - if [ -n "$(git status --porcelain)" ]; then - git add -u - git commit -m "Auto-format: ruff check --select I --fix" - echo "- \`ruff check --select I --fix\`" >> /tmp/committed_commands.txt - fi - - - name: Run generate_opcode_metadata.py - run: | - python scripts/generate_opcode_metadata.py - if [ -n "$(git status --porcelain)" ]; then - git add -u - git commit -m "Auto-generate: generate_opcode_metadata.py" - echo "- \`python scripts/generate_opcode_metadata.py\`" >> /tmp/committed_commands.txt - fi - - - name: Check for changes - id: check-changes - run: | - if [ "$(git rev-parse HEAD)" != "${{ github.event.pull_request.head.sha }}" ]; then - echo 
"has_changes=true" >> $GITHUB_OUTPUT - else - echo "has_changes=false" >> $GITHUB_OUTPUT - fi - - - name: Push formatting changes - if: steps.check-changes.outputs.has_changes == 'true' - env: - HEAD_REF: ${{ github.event.pull_request.head.ref }} - run: | - git push origin "HEAD:${HEAD_REF}" - - - name: Read committed commands - id: committed-commands - if: steps.check-changes.outputs.has_changes == 'true' - run: | - echo "list<> $GITHUB_OUTPUT - cat /tmp/committed_commands.txt >> $GITHUB_OUTPUT - echo "EOF" >> $GITHUB_OUTPUT - - - name: Comment on PR - if: steps.check-changes.outputs.has_changes == 'true' - uses: marocchino/sticky-pull-request-comment@v2 - with: - number: ${{ github.event.pull_request.number }} - message: | - **Code has been automatically formatted** - - The code in this PR has been formatted using: - ${{ steps.committed-commands.outputs.list }} - Please pull the latest changes before pushing again: - ```bash - git pull origin ${{ github.event.pull_request.head.ref }} - ``` diff --git a/.github/workflows/pr-format.yaml b/.github/workflows/pr-format.yaml new file mode 100644 index 00000000000..68cfac95991 --- /dev/null +++ b/.github/workflows/pr-format.yaml @@ -0,0 +1,56 @@ +name: Format Check + +# This workflow triggers when a PR is opened/updated +# Posts inline suggestion comments instead of auto-committing +on: + pull_request: + types: [opened, synchronize, reopened] + branches: + - main + - release + +concurrency: + group: format-check-${{ github.event.pull_request.number }} + cancel-in-progress: true + +jobs: + format_check: + permissions: + contents: read + pull-requests: write + runs-on: ubuntu-latest + timeout-minutes: 60 + steps: + - name: Checkout PR branch + uses: actions/checkout@v6 + + - name: Setup Rust + uses: dtolnay/rust-toolchain@stable + with: + components: rustfmt + + - name: Run cargo fmt + run: cargo fmt --all + + - name: Install ruff + uses: astral-sh/ruff-action@4919ec5cf1f49eff0871dbcea0da843445b837e6 # v3.6.1 + with: + 
version: "0.15.4" + args: "--version" + + - name: Run ruff format + run: ruff format + + - name: Run ruff check import sorting + run: ruff check --select I --fix + + - name: Run generate_opcode_metadata.py + run: python scripts/generate_opcode_metadata.py + + - name: Post formatting suggestions + uses: reviewdog/action-suggester@v1 + with: + tool_name: auto-format + github_token: ${{ secrets.GITHUB_TOKEN }} + level: warning + filter_mode: diff_context diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index ab5f6e230f4..d640ac87a3b 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -52,7 +52,7 @@ jobs: # target: aarch64-pc-windows-msvc fail-fast: false steps: - - uses: actions/checkout@v6.0.2 + - uses: actions/checkout@v6 - uses: dtolnay/rust-toolchain@stable - uses: cargo-bins/cargo-binstall@main @@ -88,7 +88,7 @@ jobs: # Disable this scheduled job when running on a fork. if: ${{ github.repository == 'RustPython/RustPython' || github.event_name != 'schedule' }} steps: - - uses: actions/checkout@v6.0.2 + - uses: actions/checkout@v6 - uses: dtolnay/rust-toolchain@stable with: targets: wasm32-wasip1 @@ -139,7 +139,7 @@ jobs: if: ${{ github.repository == 'RustPython/RustPython' || github.event_name != 'schedule' }} needs: [build, build-wasm] steps: - - uses: actions/checkout@v6.0.2 + - uses: actions/checkout@v6 - name: Download Binary Artifacts uses: actions/download-artifact@v8.0.0 diff --git a/.github/workflows/upgrade-pylib.lock.yml b/.github/workflows/upgrade-pylib.lock.yml index 32aa8743ff7..06b4d12b42e 100644 --- a/.github/workflows/upgrade-pylib.lock.yml +++ b/.github/workflows/upgrade-pylib.lock.yml @@ -58,7 +58,7 @@ jobs: comment_repo: "" steps: - name: Setup Scripts - uses: github/gh-aw/actions/setup@88319be75ab1adc60640307a10e5cf04b3deff1e # v0.51.5 + uses: github/gh-aw/actions/setup@f1073c5498ee46fec1530555a7c953445417c69b # v0.56.2 with: destination: /opt/gh-aw/actions - name: Check workflow file 
timestamps @@ -99,7 +99,7 @@ jobs: secret_verification_result: ${{ steps.validate-secret.outputs.verification_result }} steps: - name: Setup Scripts - uses: github/gh-aw/actions/setup@88319be75ab1adc60640307a10e5cf04b3deff1e # v0.51.5 + uses: github/gh-aw/actions/setup@f1073c5498ee46fec1530555a7c953445417c69b # v0.56.2 with: destination: /opt/gh-aw/actions - name: Checkout repository @@ -804,7 +804,7 @@ jobs: total_count: ${{ steps.missing_tool.outputs.total_count }} steps: - name: Setup Scripts - uses: github/gh-aw/actions/setup@88319be75ab1adc60640307a10e5cf04b3deff1e # v0.51.5 + uses: github/gh-aw/actions/setup@f1073c5498ee46fec1530555a7c953445417c69b # v0.56.2 with: destination: /opt/gh-aw/actions - name: Download agent output artifact @@ -925,7 +925,7 @@ jobs: success: ${{ steps.parse_results.outputs.success }} steps: - name: Setup Scripts - uses: github/gh-aw/actions/setup@88319be75ab1adc60640307a10e5cf04b3deff1e # v0.51.5 + uses: github/gh-aw/actions/setup@f1073c5498ee46fec1530555a7c953445417c69b # v0.56.2 with: destination: /opt/gh-aw/actions - name: Download agent artifacts @@ -1037,7 +1037,7 @@ jobs: process_safe_outputs_temporary_id_map: ${{ steps.process_safe_outputs.outputs.temporary_id_map }} steps: - name: Setup Scripts - uses: github/gh-aw/actions/setup@88319be75ab1adc60640307a10e5cf04b3deff1e # v0.51.5 + uses: github/gh-aw/actions/setup@f1073c5498ee46fec1530555a7c953445417c69b # v0.56.2 with: destination: /opt/gh-aw/actions - name: Download agent output artifact diff --git a/Cargo.lock b/Cargo.lock index e2a2f05a733..6c2c7eee4fa 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -249,9 +249,9 @@ checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8" [[package]] name = "aws-lc-fips-sys" -version = "0.13.11" +version = "0.13.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "df6ea8e07e2df15b9f09f2ac5ee2977369b06d116f0c4eb5fa4ad443b73c7f53" +checksum = 
"5ed8cd42adddefbdb8507fb7443fa9b666631078616b78f70ed22117b5c27d90" dependencies = [ "bindgen 0.72.1", "cc", @@ -349,6 +349,18 @@ version = "2.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "843867be96c8daad0d758b57df9392b6d8d271134fce549de6ce169ff98a92af" +[[package]] +name = "bitflagset" +version = "0.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "64b6ee310aa7af14142c8c9121775774ff601ae055ed98ba7fac96098bcde1b9" +dependencies = [ + "num-integer", + "num-traits", + "radium", + "ref-cast", +] + [[package]] name = "blake2" version = "0.10.6" @@ -1527,9 +1539,9 @@ dependencies = [ [[package]] name = "insta" -version = "1.46.1" +version = "1.46.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "248b42847813a1550dafd15296fd9748c651d0c32194559dbc05d804d54b21e8" +checksum = "e82db8c87c7f1ccecb34ce0c24399b8a73081427f3c7c50a5d597925356115e4" dependencies = [ "console", "once_cell", @@ -2258,8 +2270,7 @@ dependencies = [ [[package]] name = "parking_lot_core" version = "0.9.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2621685985a2ebf1c516881c026032ac7deafcda1a2c9b7850dc81e3dfcb64c1" +source = "git+https://github.com/youknowone/parking_lot?branch=rustpython#4392edbe879acc9c0dd94eda53d2205d3ab912c9" dependencies = [ "cfg-if", "libc", @@ -2770,6 +2781,26 @@ dependencies = [ "thiserror 1.0.69", ] +[[package]] +name = "ref-cast" +version = "1.0.25" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f354300ae66f76f1c85c5f84693f0ce81d747e2c3f21a45fef496d89c960bf7d" +dependencies = [ + "ref-cast-impl", +] + +[[package]] +name = "ref-cast-impl" +version = "1.0.25" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b7186006dcb21920990093f30e3dea63b7d6e977bf1256be20c3563a5db070da" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "regalloc2" version = "0.13.5" @@ 
-2863,7 +2894,7 @@ dependencies = [ [[package]] name = "ruff_python_ast" version = "0.0.0" -source = "git+https://github.com/astral-sh/ruff.git?rev=f14edd8661e2803254f89265548c7487f47a09f6#f14edd8661e2803254f89265548c7487f47a09f6" +source = "git+https://github.com/astral-sh/ruff.git?rev=5e4a3d9c3b381df20f6a52caef0f56ed0ebc74be#5e4a3d9c3b381df20f6a52caef0f56ed0ebc74be" dependencies = [ "aho-corasick", "bitflags 2.11.0", @@ -2881,7 +2912,7 @@ dependencies = [ [[package]] name = "ruff_python_parser" version = "0.0.0" -source = "git+https://github.com/astral-sh/ruff.git?rev=f14edd8661e2803254f89265548c7487f47a09f6#f14edd8661e2803254f89265548c7487f47a09f6" +source = "git+https://github.com/astral-sh/ruff.git?rev=5e4a3d9c3b381df20f6a52caef0f56ed0ebc74be#5e4a3d9c3b381df20f6a52caef0f56ed0ebc74be" dependencies = [ "bitflags 2.11.0", "bstr", @@ -2901,7 +2932,7 @@ dependencies = [ [[package]] name = "ruff_python_trivia" version = "0.0.0" -source = "git+https://github.com/astral-sh/ruff.git?rev=f14edd8661e2803254f89265548c7487f47a09f6#f14edd8661e2803254f89265548c7487f47a09f6" +source = "git+https://github.com/astral-sh/ruff.git?rev=5e4a3d9c3b381df20f6a52caef0f56ed0ebc74be#5e4a3d9c3b381df20f6a52caef0f56ed0ebc74be" dependencies = [ "itertools 0.14.0", "ruff_source_file", @@ -2912,7 +2943,7 @@ dependencies = [ [[package]] name = "ruff_source_file" version = "0.0.0" -source = "git+https://github.com/astral-sh/ruff.git?rev=f14edd8661e2803254f89265548c7487f47a09f6#f14edd8661e2803254f89265548c7487f47a09f6" +source = "git+https://github.com/astral-sh/ruff.git?rev=5e4a3d9c3b381df20f6a52caef0f56ed0ebc74be#5e4a3d9c3b381df20f6a52caef0f56ed0ebc74be" dependencies = [ "memchr", "ruff_text_size", @@ -2921,7 +2952,7 @@ dependencies = [ [[package]] name = "ruff_text_size" version = "0.0.0" -source = "git+https://github.com/astral-sh/ruff.git?rev=f14edd8661e2803254f89265548c7487f47a09f6#f14edd8661e2803254f89265548c7487f47a09f6" +source = 
"git+https://github.com/astral-sh/ruff.git?rev=5e4a3d9c3b381df20f6a52caef0f56ed0ebc74be#5e4a3d9c3b381df20f6a52caef0f56ed0ebc74be" dependencies = [ "get-size2", ] @@ -2956,9 +2987,9 @@ dependencies = [ [[package]] name = "rustls" -version = "0.23.36" +version = "0.23.37" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c665f33d38cea657d9614f766881e4d510e0eda4239891eea56b4cadcf01801b" +checksum = "758025cb5fccfd3bc2fd74708fd4682be41d99e5dff73c377c0646c6012c73a4" dependencies = [ "aws-lc-rs", "once_cell", @@ -3129,6 +3160,7 @@ name = "rustpython-compiler-core" version = "0.4.0" dependencies = [ "bitflags 2.11.0", + "bitflagset", "itertools 0.14.0", "lz4_flex", "malachite-bigint", @@ -3283,6 +3315,10 @@ dependencies = [ "pkcs8", "pymath", "rand_core 0.9.5", + "ruff_python_ast", + "ruff_python_parser", + "ruff_source_file", + "ruff_text_size", "rustix", "rustls", "rustls-native-certs", @@ -3706,12 +3742,12 @@ checksum = "67b1b7a3b5fe4f1376887184045fcf45c69e92af734b7aaddc05fb777b6fbd03" [[package]] name = "socket2" -version = "0.6.2" +version = "0.6.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "86f4aa3ad99f2088c990dfa82d367e19cb29268ed67c574d10d0a4bfe71f07e0" +checksum = "3a766e1110788c36f4fa1c2b71b387a7815aa65f88ce0229841826633d93723e" dependencies = [ "libc", - "windows-sys 0.60.2", + "windows-sys 0.61.2", ] [[package]] @@ -3762,9 +3798,9 @@ checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" [[package]] name = "syn" -version = "2.0.114" +version = "2.0.117" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d4d107df263a3013ef9b1879b0df87d706ff80f65a86ea879bd9c31f9b307c2a" +checksum = "e665b8803e7b1d2a727f4023456bbbbe74da67099c585258af0ad9c5013b9b99" dependencies = [ "proc-macro2", "quote", @@ -4281,9 +4317,9 @@ checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" [[package]] name = "uuid" -version = "1.21.0" +version = "1.22.0" 
source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b672338555252d43fd2240c714dc444b8c6fb0a5c5335e65a07bba7742735ddb" +checksum = "a68d3c8f01c0cfa54a75291d83601161799e4a89a39e0929f4b0354d88757a37" dependencies = [ "atomic", "js-sys", diff --git a/Cargo.toml b/Cargo.toml index 664340c23cf..fd95e664085 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -101,6 +101,7 @@ opt-level = 3 lto = "thin" [patch.crates-io] +parking_lot_core = { git = "https://github.com/youknowone/parking_lot", branch = "rustpython" } # REDOX START, Uncomment when you want to compile/check with redoxer # REDOX END @@ -155,17 +156,18 @@ rustpython-sre_engine = { path = "crates/sre_engine", version = "0.4.0" } rustpython-wtf8 = { path = "crates/wtf8", version = "0.4.0" } rustpython-doc = { path = "crates/doc", version = "0.4.0" } -# Ruff tag 0.15.4 is based on commit f14edd8661e2803254f89265548c7487f47a09f6 +# Ruff tag 0.15.5 is based on commit 5e4a3d9c3b381df20f6a52caef0f56ed0ebc74be # at the time of this capture. We use the commit hash to ensure reproducible builds. 
-ruff_python_parser = { git = "https://github.com/astral-sh/ruff.git", rev = "f14edd8661e2803254f89265548c7487f47a09f6" } -ruff_python_ast = { git = "https://github.com/astral-sh/ruff.git", rev = "f14edd8661e2803254f89265548c7487f47a09f6" } -ruff_text_size = { git = "https://github.com/astral-sh/ruff.git", rev = "f14edd8661e2803254f89265548c7487f47a09f6" } -ruff_source_file = { git = "https://github.com/astral-sh/ruff.git", rev = "f14edd8661e2803254f89265548c7487f47a09f6" } +ruff_python_parser = { git = "https://github.com/astral-sh/ruff.git", rev = "5e4a3d9c3b381df20f6a52caef0f56ed0ebc74be" } +ruff_python_ast = { git = "https://github.com/astral-sh/ruff.git", rev = "5e4a3d9c3b381df20f6a52caef0f56ed0ebc74be" } +ruff_text_size = { git = "https://github.com/astral-sh/ruff.git", rev = "5e4a3d9c3b381df20f6a52caef0f56ed0ebc74be" } +ruff_source_file = { git = "https://github.com/astral-sh/ruff.git", rev = "5e4a3d9c3b381df20f6a52caef0f56ed0ebc74be" } phf = { version = "0.13.1", default-features = false, features = ["macros"]} ahash = "0.8.12" ascii = "1.1" bitflags = "2.11.0" +bitflagset = "0.0.3" bstr = "1" bytes = "1.11.1" cfg-if = "1.0" diff --git a/Lib/locale.py b/Lib/locale.py index db6d0abb26b..dfedc6386cb 100644 --- a/Lib/locale.py +++ b/Lib/locale.py @@ -13,7 +13,6 @@ import sys import encodings import encodings.aliases -import re import _collections_abc from builtins import str as _builtin_str import functools @@ -177,8 +176,7 @@ def _strip_padding(s, amount): amount -= 1 return s[lpos:rpos+1] -_percent_re = re.compile(r'%(?:\((?P.*?)\))?' - r'(?P[-#0-9 +*.hlL]*?)[eEfFgGdiouxXcrs%]') +_percent_re = None def _format(percent, value, grouping=False, monetary=False, *additional): if additional: @@ -217,6 +215,13 @@ def format_string(f, val, grouping=False, monetary=False): Grouping is applied if the third parameter is true. 
Conversion uses monetary thousands separator and grouping strings if forth parameter monetary is true.""" + global _percent_re + if _percent_re is None: + import re + + _percent_re = re.compile(r'%(?:\((?P.*?)\))?(?P[-#0-9 +*.hlL]*?)[eEfFgGdiouxXcrs%]') + percents = list(_percent_re.finditer(f)) new_f = _percent_re.sub('%s', f) diff --git a/Lib/poplib.py b/Lib/poplib.py new file mode 100644 index 00000000000..4469bff44b4 --- /dev/null +++ b/Lib/poplib.py @@ -0,0 +1,477 @@ +"""A POP3 client class. + +Based on the J. Myers POP3 draft, Jan. 96 +""" + +# Author: David Ascher +# [heavily stealing from nntplib.py] +# Updated: Piers Lauder [Jul '97] +# String method conversion and test jig improvements by ESR, February 2001. +# Added the POP3_SSL class. Methods loosely based on IMAP_SSL. Hector Urtubia Aug 2003 + +# Example (see the test function at the end of this file) + +# Imports + +import errno +import re +import socket +import sys + +try: + import ssl + HAVE_SSL = True +except ImportError: + HAVE_SSL = False + +__all__ = ["POP3","error_proto"] + +# Exception raised when an error or invalid response is received: + +class error_proto(Exception): pass + +# Standard Port +POP3_PORT = 110 + +# POP SSL PORT +POP3_SSL_PORT = 995 + +# Line terminators (we always output CRLF, but accept any of CRLF, LFCR, LF) +CR = b'\r' +LF = b'\n' +CRLF = CR+LF + +# maximal line length when calling readline(). This is to prevent +# reading arbitrary length lines. RFC 1939 limits POP3 line length to +# 512 characters, including CRLF. We have selected 2048 just to be on +# the safe side. +_MAXLINE = 2048 + + +class POP3: + + """This class supports both the minimal and optional command sets. + Arguments can be strings or integers (where appropriate) + (e.g.: retr(1) and retr('1') both work equally well. 
+ + Minimal Command Set: + USER name user(name) + PASS string pass_(string) + STAT stat() + LIST [msg] list(msg = None) + RETR msg retr(msg) + DELE msg dele(msg) + NOOP noop() + RSET rset() + QUIT quit() + + Optional Commands (some servers support these): + RPOP name rpop(name) + APOP name digest apop(name, digest) + TOP msg n top(msg, n) + UIDL [msg] uidl(msg = None) + CAPA capa() + STLS stls() + UTF8 utf8() + + Raises one exception: 'error_proto'. + + Instantiate with: + POP3(hostname, port=110) + + NB: the POP protocol locks the mailbox from user + authorization until QUIT, so be sure to get in, suck + the messages, and quit, each time you access the + mailbox. + + POP is a line-based protocol, which means large mail + messages consume lots of python cycles reading them + line-by-line. + + If it's available on your mail server, use IMAP4 + instead, it doesn't suffer from the two problems + above. + """ + + encoding = 'UTF-8' + + def __init__(self, host, port=POP3_PORT, + timeout=socket._GLOBAL_DEFAULT_TIMEOUT): + self.host = host + self.port = port + self._tls_established = False + sys.audit("poplib.connect", self, host, port) + self.sock = self._create_socket(timeout) + self.file = self.sock.makefile('rb') + self._debugging = 0 + self.welcome = self._getresp() + + def _create_socket(self, timeout): + if timeout is not None and not timeout: + raise ValueError('Non-blocking socket (timeout=0) is not supported') + return socket.create_connection((self.host, self.port), timeout) + + def _putline(self, line): + if self._debugging > 1: print('*put*', repr(line)) + sys.audit("poplib.putline", self, line) + self.sock.sendall(line + CRLF) + + + # Internal: send one command to the server (through _putline()) + + def _putcmd(self, line): + if self._debugging: print('*cmd*', repr(line)) + line = bytes(line, self.encoding) + self._putline(line) + + + # Internal: return one line from the server, stripping CRLF. + # This is where all the CPU time of this module is consumed. 
+ # Raise error_proto('-ERR EOF') if the connection is closed. + + def _getline(self): + line = self.file.readline(_MAXLINE + 1) + if len(line) > _MAXLINE: + raise error_proto('line too long') + + if self._debugging > 1: print('*get*', repr(line)) + if not line: raise error_proto('-ERR EOF') + octets = len(line) + # server can send any combination of CR & LF + # however, 'readline()' returns lines ending in LF + # so only possibilities are ...LF, ...CRLF, CR...LF + if line[-2:] == CRLF: + return line[:-2], octets + if line[:1] == CR: + return line[1:-1], octets + return line[:-1], octets + + + # Internal: get a response from the server. + # Raise 'error_proto' if the response doesn't start with '+'. + + def _getresp(self): + resp, o = self._getline() + if self._debugging > 1: print('*resp*', repr(resp)) + if not resp.startswith(b'+'): + raise error_proto(resp) + return resp + + + # Internal: get a response plus following text from the server. + + def _getlongresp(self): + resp = self._getresp() + list = []; octets = 0 + line, o = self._getline() + while line != b'.': + if line.startswith(b'..'): + o = o-1 + line = line[1:] + octets = octets + o + list.append(line) + line, o = self._getline() + return resp, list, octets + + + # Internal: send a command and get the response + + def _shortcmd(self, line): + self._putcmd(line) + return self._getresp() + + + # Internal: send a command and get the response plus following text + + def _longcmd(self, line): + self._putcmd(line) + return self._getlongresp() + + + # These can be useful: + + def getwelcome(self): + return self.welcome + + + def set_debuglevel(self, level): + self._debugging = level + + + # Here are all the POP commands: + + def user(self, user): + """Send user name, return response + + (should indicate password required). + """ + return self._shortcmd('USER %s' % user) + + + def pass_(self, pswd): + """Send password, return response + + (response includes message count, mailbox size). 
+ + NB: mailbox is locked by server from here to 'quit()' + """ + return self._shortcmd('PASS %s' % pswd) + + + def stat(self): + """Get mailbox status. + + Result is tuple of 2 ints (message count, mailbox size) + """ + retval = self._shortcmd('STAT') + rets = retval.split() + if self._debugging: print('*stat*', repr(rets)) + + # Check if the response has enough elements + # RFC 1939 requires at least 3 elements (+OK, message count, mailbox size) + # but allows additional data after the required fields + if len(rets) < 3: + raise error_proto("Invalid STAT response format") + + try: + numMessages = int(rets[1]) + sizeMessages = int(rets[2]) + except ValueError: + raise error_proto("Invalid STAT response data: non-numeric values") + + return (numMessages, sizeMessages) + + + def list(self, which=None): + """Request listing, return result. + + Result without a message number argument is in form + ['response', ['mesg_num octets', ...], octets]. + + Result when a message number argument is given is a + single response: the "scan listing" for that message. + """ + if which is not None: + return self._shortcmd('LIST %s' % which) + return self._longcmd('LIST') + + + def retr(self, which): + """Retrieve whole message number 'which'. + + Result is in form ['response', ['line', ...], octets]. + """ + return self._longcmd('RETR %s' % which) + + + def dele(self, which): + """Delete message number 'which'. + + Result is 'response'. + """ + return self._shortcmd('DELE %s' % which) + + + def noop(self): + """Does nothing. + + One supposes the response indicates the server is alive. 
+ """ + return self._shortcmd('NOOP') + + + def rset(self): + """Unmark all messages marked for deletion.""" + return self._shortcmd('RSET') + + + def quit(self): + """Signoff: commit changes on server, unlock mailbox, close connection.""" + resp = self._shortcmd('QUIT') + self.close() + return resp + + def close(self): + """Close the connection without assuming anything about it.""" + try: + file = self.file + self.file = None + if file is not None: + file.close() + finally: + sock = self.sock + self.sock = None + if sock is not None: + try: + sock.shutdown(socket.SHUT_RDWR) + except OSError as exc: + # The server might already have closed the connection. + # On Windows, this may result in WSAEINVAL (error 10022): + # An invalid operation was attempted. + if (exc.errno != errno.ENOTCONN + and getattr(exc, 'winerror', 0) != 10022): + raise + finally: + sock.close() + + #__del__ = quit + + + # optional commands: + + def rpop(self, user): + """Send RPOP command to access the mailbox with an alternate user.""" + return self._shortcmd('RPOP %s' % user) + + + timestamp = re.compile(br'\+OK.[^<]*(<.*>)') + + def apop(self, user, password): + """Authorisation + + - only possible if server has supplied a timestamp in initial greeting. + + Args: + user - mailbox user; + password - mailbox password. + + NB: mailbox is locked by server from here to 'quit()' + """ + secret = bytes(password, self.encoding) + m = self.timestamp.match(self.welcome) + if not m: + raise error_proto('-ERR APOP not supported by server') + import hashlib + digest = m.group(1)+secret + digest = hashlib.md5(digest).hexdigest() + return self._shortcmd('APOP %s %s' % (user, digest)) + + + def top(self, which, howmuch): + """Retrieve message header of message number 'which' + and first 'howmuch' lines of message body. + + Result is in form ['response', ['line', ...], octets]. 
+ """ + return self._longcmd('TOP %s %s' % (which, howmuch)) + + + def uidl(self, which=None): + """Return message digest (unique id) list. + + If 'which', result contains unique id for that message + in the form 'response mesgnum uid', otherwise result is + the list ['response', ['mesgnum uid', ...], octets] + """ + if which is not None: + return self._shortcmd('UIDL %s' % which) + return self._longcmd('UIDL') + + + def utf8(self): + """Try to enter UTF-8 mode (see RFC 6856). Returns server response. + """ + return self._shortcmd('UTF8') + + + def capa(self): + """Return server capabilities (RFC 2449) as a dictionary + >>> c=poplib.POP3('localhost') + >>> c.capa() + {'IMPLEMENTATION': ['Cyrus', 'POP3', 'server', 'v2.2.12'], + 'TOP': [], 'LOGIN-DELAY': ['0'], 'AUTH-RESP-CODE': [], + 'EXPIRE': ['NEVER'], 'USER': [], 'STLS': [], 'PIPELINING': [], + 'UIDL': [], 'RESP-CODES': []} + >>> + + Really, according to RFC 2449, the cyrus folks should avoid + having the implementation split into multiple arguments... + """ + def _parsecap(line): + lst = line.decode('ascii').split() + return lst[0], lst[1:] + + caps = {} + try: + resp = self._longcmd('CAPA') + rawcaps = resp[1] + for capline in rawcaps: + capnm, capargs = _parsecap(capline) + caps[capnm] = capargs + except error_proto: + raise error_proto('-ERR CAPA not supported by server') + return caps + + + def stls(self, context=None): + """Start a TLS session on the active connection as specified in RFC 2595. 
+ + context - a ssl.SSLContext + """ + if not HAVE_SSL: + raise error_proto('-ERR TLS support missing') + if self._tls_established: + raise error_proto('-ERR TLS session already established') + caps = self.capa() + if not 'STLS' in caps: + raise error_proto('-ERR STLS not supported by server') + if context is None: + context = ssl._create_stdlib_context() + resp = self._shortcmd('STLS') + self.sock = context.wrap_socket(self.sock, + server_hostname=self.host) + self.file = self.sock.makefile('rb') + self._tls_established = True + return resp + + +if HAVE_SSL: + + class POP3_SSL(POP3): + """POP3 client class over SSL connection + + Instantiate with: POP3_SSL(hostname, port=995, context=None) + + hostname - the hostname of the pop3 over ssl server + port - port number + context - a ssl.SSLContext + + See the methods of the parent class POP3 for more documentation. + """ + + def __init__(self, host, port=POP3_SSL_PORT, + *, timeout=socket._GLOBAL_DEFAULT_TIMEOUT, context=None): + if context is None: + context = ssl._create_stdlib_context() + self.context = context + POP3.__init__(self, host, port, timeout) + + def _create_socket(self, timeout): + sock = POP3._create_socket(self, timeout) + sock = self.context.wrap_socket(sock, + server_hostname=self.host) + return sock + + def stls(self, context=None): + """The method unconditionally raises an exception since the + STLS command doesn't make any sense on an already established + SSL/TLS session. 
+ """ + raise error_proto('-ERR TLS session already established') + + __all__.append("POP3_SSL") + +if __name__ == "__main__": + a = POP3(sys.argv[1]) + print(a.getwelcome()) + a.user(sys.argv[2]) + a.pass_(sys.argv[3]) + a.list() + (numMsgs, totalSize) = a.stat() + for i in range(1, numMsgs + 1): + (header, msg, octets) = a.retr(i) + print("Message %d:" % i) + for line in msg: + print(' ' + line) + print('-----------------------') + a.quit() diff --git a/Lib/test/_test_atexit.py b/Lib/test/_test_atexit.py index db4edd72c51..2e961d6a485 100644 --- a/Lib/test/_test_atexit.py +++ b/Lib/test/_test_atexit.py @@ -47,7 +47,6 @@ def func2(*args, **kwargs): ('func2', (), {}), ('func1', (1, 2), {})]) - @unittest.expectedFailure # TODO: RUSTPYTHON def test_badargs(self): def func(): pass @@ -55,14 +54,12 @@ def func(): # func() has no parameter, but it's called with 2 parameters self.assert_raises_unraisable(TypeError, func, 1 ,2) - @unittest.expectedFailure # TODO: RUSTPYTHON def test_raise(self): def raise_type_error(): raise TypeError self.assert_raises_unraisable(TypeError, raise_type_error) - @unittest.expectedFailure # TODO: RUSTPYTHON def test_raise_unnormalized(self): # bpo-10756: Make sure that an unnormalized exception is handled # properly. 
@@ -71,7 +68,6 @@ def div_zero(): self.assert_raises_unraisable(ZeroDivisionError, div_zero) - @unittest.expectedFailure # TODO: RUSTPYTHON def test_exit(self): self.assert_raises_unraisable(SystemExit, sys.exit) @@ -122,7 +118,6 @@ def test_bound_methods(self): atexit._run_exitfuncs() self.assertEqual(l, [5]) - @unittest.expectedFailure # TODO: RUSTPYTHON def test_atexit_with_unregistered_function(self): # See bpo-46025 for more info def func(): @@ -140,7 +135,6 @@ def func(): finally: atexit.unregister(func) - @unittest.skip("TODO: RUSTPYTHON; Hangs") def test_eq_unregister_clear(self): # Issue #112127: callback's __eq__ may call unregister or _clear class Evil: @@ -154,7 +148,6 @@ def __eq__(self, other): atexit.unregister(Evil()) atexit._clear() - @unittest.skip("TODO: RUSTPYTHON; Hangs") def test_eq_unregister(self): # Issue #112127: callback's __eq__ may call unregister def f1(): diff --git a/Lib/test/_test_gc_fast_cycles.py b/Lib/test/_test_gc_fast_cycles.py new file mode 100644 index 00000000000..4e2c7d72a02 --- /dev/null +++ b/Lib/test/_test_gc_fast_cycles.py @@ -0,0 +1,48 @@ +# Run by test_gc. 
+from test import support +import _testinternalcapi +import gc +import unittest + +class IncrementalGCTests(unittest.TestCase): + + # Use small increments to emulate longer running process in a shorter time + @support.gc_threshold(200, 10) + def test_incremental_gc_handles_fast_cycle_creation(self): + + class LinkedList: + + #Use slots to reduce number of implicit objects + __slots__ = "next", "prev", "surprise" + + def __init__(self, next=None, prev=None): + self.next = next + if next is not None: + next.prev = self + self.prev = prev + if prev is not None: + prev.next = self + + def make_ll(depth): + head = LinkedList() + for i in range(depth): + head = LinkedList(head, head.prev) + return head + + head = make_ll(1000) + + assert(gc.isenabled()) + olds = [] + initial_heap_size = _testinternalcapi.get_tracked_heap_size() + for i in range(20_000): + newhead = make_ll(20) + newhead.surprise = head + olds.append(newhead) + if len(olds) == 20: + new_objects = _testinternalcapi.get_tracked_heap_size() - initial_heap_size + self.assertLess(new_objects, 27_000, f"Heap growing. 
Reached limit after {i} iterations") + del olds[:] + + +if __name__ == "__main__": + unittest.main() diff --git a/Lib/test/_test_multiprocessing.py b/Lib/test/_test_multiprocessing.py index 894cebda57b..35ce70fced2 100644 --- a/Lib/test/_test_multiprocessing.py +++ b/Lib/test/_test_multiprocessing.py @@ -1459,7 +1459,7 @@ def _acquire_release(lock, timeout, l=None, n=1): for _ in range(n): lock.release() - @unittest.skip("TODO: RUSTPYTHON; flaky timeout") + @unittest.skip("TODO: RUSTPYTHON; flaky timeout - thread start latency") def test_repr_rlock(self): if self.TYPE != 'processes': self.skipTest('test not appropriate for {}'.format(self.TYPE)) @@ -4415,7 +4415,6 @@ def test_shared_memory_across_processes(self): sms.close() - @unittest.skip("TODO: RUSTPYTHON; flaky") @unittest.skipIf(os.name != "posix", "not feasible in non-posix platforms") def test_shared_memory_SharedMemoryServer_ignores_sigint(self): # bpo-36368: protect SharedMemoryManager server process from @@ -4440,7 +4439,6 @@ def test_shared_memory_SharedMemoryServer_ignores_sigint(self): smm.shutdown() - @unittest.skip("TODO: RUSTPYTHON: sem_unlink cleanup race causes spurious stderr output") @unittest.skipIf(os.name != "posix", "resource_tracker is posix only") @resource_tracker_format_subtests def test_shared_memory_SharedMemoryManager_reuses_resource_tracker(self): diff --git a/Lib/test/test_asyncio/test_events.py b/Lib/test/test_asyncio/test_events.py index 1a06b426f71..b60c7452f3f 100644 --- a/Lib/test/test_asyncio/test_events.py +++ b/Lib/test/test_asyncio/test_events.py @@ -2876,7 +2876,6 @@ def test_get_event_loop_after_set_none(self): policy.set_event_loop(None) self.assertRaises(RuntimeError, policy.get_event_loop) - @unittest.expectedFailure # TODO: RUSTPYTHON; - mock.patch doesn't work correctly with threading.current_thread @mock.patch('asyncio.events.threading.current_thread') def test_get_event_loop_thread(self, m_current_thread): diff --git a/Lib/test/test_asyncio/test_unix_events.py 
b/Lib/test/test_asyncio/test_unix_events.py index 0faf32f79ea..520f5c733c3 100644 --- a/Lib/test/test_asyncio/test_unix_events.py +++ b/Lib/test/test_asyncio/test_unix_events.py @@ -1179,8 +1179,6 @@ async def runner(): wsock.close() -# TODO: RUSTPYTHON, fork() segfaults due to stale parking_lot global state -@unittest.skip("TODO: RUSTPYTHON") @support.requires_fork() class TestFork(unittest.TestCase): diff --git a/Lib/test/test_concurrent_futures/test_process_pool.py b/Lib/test/test_concurrent_futures/test_process_pool.py index ef318dfc7e1..5d4e9677f5c 100644 --- a/Lib/test/test_concurrent_futures/test_process_pool.py +++ b/Lib/test/test_concurrent_futures/test_process_pool.py @@ -85,7 +85,6 @@ def test_traceback(self): self.assertIn('raise RuntimeError(123) # some comment', f1.getvalue()) - @unittest.skip('TODO: RUSTPYTHON flaky EOFError') @hashlib_helper.requires_hashdigest('md5') def test_ressources_gced_in_workers(self): # Ensure that argument for a job are correctly gc-ed after the job diff --git a/Lib/test/test_concurrent_futures/test_wait.py b/Lib/test/test_concurrent_futures/test_wait.py index 818e0d51a2c..6749a690f6c 100644 --- a/Lib/test/test_concurrent_futures/test_wait.py +++ b/Lib/test/test_concurrent_futures/test_wait.py @@ -200,20 +200,5 @@ def future_func(): def setUpModule(): setup_module() -class ProcessPoolForkWaitTest(ProcessPoolForkWaitTest): # TODO: RUSTPYTHON - @unittest.skipIf(sys.platform == 'linux', "TODO: RUSTPYTHON flaky") - def test_first_completed(self): super().test_first_completed() # TODO: RUSTPYTHON - @unittest.skipIf(sys.platform == 'linux', "TODO: RUSTPYTHON Fatal Python error: Segmentation fault") - def test_first_completed_some_already_completed(self): super().test_first_completed_some_already_completed() # TODO: RUSTPYTHON - @unittest.skipIf(sys.platform != 'win32', "TODO: RUSTPYTHON flaky") - def test_first_exception(self): super().test_first_exception() # TODO: RUSTPYTHON - @unittest.skipIf(sys.platform == 'linux', "TODO: 
RUSTPYTHON flaky") - def test_first_exception_one_already_failed(self): super().test_first_exception_one_already_failed() # TODO: RUSTPYTHON - @unittest.skipIf(sys.platform != 'win32', "TODO: RUSTPYTHON flaky") - def test_first_exception_some_already_complete(self): super().test_first_exception_some_already_complete() # TODO: RUSTPYTHON - @unittest.skipIf(sys.platform == 'linux', "TODO: RUSTPYTHON Fatal Python error: Segmentation fault") - def test_timeout(self): super().test_timeout() # TODO: RUSTPYTHON - - if __name__ == "__main__": unittest.main() diff --git a/Lib/test/test_dataclasses/__init__.py b/Lib/test/test_dataclasses/__init__.py index 12db84a1209..dfe6b89f1ed 100644 --- a/Lib/test/test_dataclasses/__init__.py +++ b/Lib/test/test_dataclasses/__init__.py @@ -3672,7 +3672,6 @@ class A: self.assertEqual(obj.a, 'a') self.assertEqual(obj.b, 'b') - @unittest.expectedFailure # TODO: RUSTPYTHON def test_slots_no_weakref(self): @dataclass(slots=True) class A: @@ -3687,7 +3686,6 @@ class A: with self.assertRaises(AttributeError): a.__weakref__ - @unittest.expectedFailure # TODO: RUSTPYTHON def test_slots_weakref(self): @dataclass(slots=True, weakref_slot=True) class A: @@ -3748,7 +3746,6 @@ def test_weakref_slot_make_dataclass(self): "weakref_slot is True but slots is False"): B = make_dataclass('B', [('a', int),], weakref_slot=True) - @unittest.expectedFailure # TODO: RUSTPYTHON def test_weakref_slot_subclass_weakref_slot(self): @dataclass(slots=True, weakref_slot=True) class Base: @@ -3767,7 +3764,6 @@ class A(Base): a_ref = weakref.ref(a) self.assertIs(a.__weakref__, a_ref) - @unittest.expectedFailure # TODO: RUSTPYTHON def test_weakref_slot_subclass_no_weakref_slot(self): @dataclass(slots=True, weakref_slot=True) class Base: @@ -3785,7 +3781,6 @@ class A(Base): a_ref = weakref.ref(a) self.assertIs(a.__weakref__, a_ref) - @unittest.expectedFailure # TODO: RUSTPYTHON def test_weakref_slot_normal_base_weakref_slot(self): class Base: __slots__ = ('__weakref__',) @@ 
-3830,7 +3825,6 @@ class B[T2]: self.assertTrue(B.__weakref__) B() - @unittest.expectedFailure # TODO: RUSTPYTHON def test_dataclass_derived_generic_from_base(self): T = typing.TypeVar('T') diff --git a/Lib/test/test_descr.py b/Lib/test/test_descr.py index c948d156cdb..1f7c5452c4d 100644 --- a/Lib/test/test_descr.py +++ b/Lib/test/test_descr.py @@ -1321,7 +1321,6 @@ class X(object): with self.assertRaisesRegex(AttributeError, "'X' object has no attribute 'a'"): X().a - @unittest.expectedFailure # TODO: RUSTPYTHON def test_slots_special(self): # Testing __dict__ and __weakref__ in __slots__... class D(object): @@ -2294,7 +2293,6 @@ def __contains__(self, value): self.assertIn(i, p10) self.assertNotIn(10, p10) - @unittest.expectedFailure # TODO: RUSTPYTHON def test_weakrefs(self): # Testing weak references... import weakref @@ -3976,7 +3974,6 @@ def __init__(self, x): o = trash(o) del o - @unittest.expectedFailure # TODO: RUSTPYTHON def test_slots_multiple_inheritance(self): # SF bug 575229, multiple inheritance w/ slots dumps core class A(object): diff --git a/Lib/test/test_fork1.py b/Lib/test/test_fork1.py index 4f4a5ee0507..a6523bbc518 100644 --- a/Lib/test/test_fork1.py +++ b/Lib/test/test_fork1.py @@ -19,7 +19,6 @@ class ForkTest(ForkWait): - @unittest.expectedFailure # TODO: RUSTPYTHON; AssertionError: process 44587 exited with code 1, but exit code 42 is expected def test_threaded_import_lock_fork(self): """Check fork() in main thread works while a subthread is doing an import""" import_started = threading.Event() diff --git a/Lib/test/test_format.py b/Lib/test/test_format.py index 5badff612b8..6868c87171d 100644 --- a/Lib/test/test_format.py +++ b/Lib/test/test_format.py @@ -423,7 +423,6 @@ def test_non_ascii(self): self.assertEqual(format(1+2j, "\u2007^8"), "\u2007(1+2j)\u2007") self.assertEqual(format(0j, "\u2007^4"), "\u20070j\u2007") - @unittest.skip("TODO: RUSTPYTHON; formatting does not support locales. 
See https://github.com/RustPython/RustPython/issues/5181") def test_locale(self): try: oldloc = locale.setlocale(locale.LC_ALL) diff --git a/Lib/test/test_gc.py b/Lib/test/test_gc.py index 3e3092dcae1..879a2875aaa 100644 --- a/Lib/test/test_gc.py +++ b/Lib/test/test_gc.py @@ -236,6 +236,8 @@ def test_function(self): # is 3 because it includes f's code object. self.assertIn(gc.collect(), (2, 3)) + # TODO: RUSTPYTHON - weakref clear ordering differs from 3.15+ + @unittest.expectedFailure def test_function_tp_clear_leaves_consistent_state(self): # https://github.com/python/cpython/issues/91636 code = """if 1: @@ -262,9 +264,11 @@ class Cyclic(tuple): # finalizer. def __del__(self): - # 5. Create a weakref to `func` now. If we had created - # it earlier, it would have been cleared by the - # garbage collector before calling the finalizers. + # 5. Create a weakref to `func` now. In previous + # versions of Python, this would avoid having it + # cleared by the garbage collector before calling + # the finalizers. Now, weakrefs get cleared after + # calling finalizers. self[1].ref = weakref.ref(self[0]) # 6. Drop the global reference to `latefin`. The only @@ -293,16 +297,42 @@ def func(): # which will find `cyc` and `func` as garbage. gc.collect() - # 9. Previously, this would crash because `func_qualname` - # had been NULL-ed out by func_clear(). + # 9. Previously, this would crash because the weakref + # created in the finalizer revealed the function after + # `tp_clear` was called and `func_qualname` + # had been NULL-ed out by func_clear(). Now, we clear + # weakrefs to unreachable objects before calling `tp_clear` + # but after calling finalizers. print(f"{func=}") """ - # We're mostly just checking that this doesn't crash. rc, stdout, stderr = assert_python_ok("-c", code) self.assertEqual(rc, 0) - self.assertRegex(stdout, rb"""\A\s*func=\s*\z""") + # The `func` global is None because the weakref was cleared. 
+ self.assertRegex(stdout, rb"""\A\s*func=None""") self.assertFalse(stderr) + # TODO: RUSTPYTHON - _datetime module not available + @unittest.expectedFailure + def test_datetime_weakref_cycle(self): + # https://github.com/python/cpython/issues/132413 + # If the weakref used by the datetime extension gets cleared by the GC (due to being + # in an unreachable cycle) then datetime functions would crash (get_module_state() + # was returning a NULL pointer). This bug is fixed by clearing weakrefs without + # callbacks *after* running finalizers. + code = """if 1: + import _datetime + class C: + def __del__(self): + print('__del__ called') + _datetime.timedelta(days=1) # crash? + + l = [C()] + l.append(l) + """ + rc, stdout, stderr = assert_python_ok("-c", code) + self.assertEqual(rc, 0) + self.assertEqual(stdout.strip(), b'__del__ called') + @refcount_test def test_frame(self): def f(): @@ -652,9 +682,8 @@ def callback(ignored): gc.collect() self.assertEqual(len(ouch), 2) # else the callbacks didn't run for x in ouch: - # If the callback resurrected one of these guys, the instance - # would be damaged, with an empty __dict__. - self.assertEqual(x, None) + # The weakref should be cleared before executing the callback. 
+ self.assertIsNone(x) def test_bug21435(self): # This is a poor test - its only virtue is that it happened to @@ -821,11 +850,15 @@ def test_get_stats(self): self.assertEqual(len(stats), 3) for st in stats: self.assertIsInstance(st, dict) - self.assertEqual(set(st), - {"collected", "collections", "uncollectable"}) + self.assertEqual( + set(st), + {"collected", "collections", "uncollectable", "candidates", "duration"} + ) self.assertGreaterEqual(st["collected"], 0) self.assertGreaterEqual(st["collections"], 0) self.assertGreaterEqual(st["uncollectable"], 0) + self.assertGreaterEqual(st["candidates"], 0) + self.assertGreaterEqual(st["duration"], 0) # Check that collection counts are incremented correctly if gc.isenabled(): self.addCleanup(gc.enable) @@ -836,11 +869,25 @@ def test_get_stats(self): self.assertEqual(new[0]["collections"], old[0]["collections"] + 1) self.assertEqual(new[1]["collections"], old[1]["collections"]) self.assertEqual(new[2]["collections"], old[2]["collections"]) + self.assertGreater(new[0]["duration"], old[0]["duration"]) + self.assertEqual(new[1]["duration"], old[1]["duration"]) + self.assertEqual(new[2]["duration"], old[2]["duration"]) + for stat in ["collected", "uncollectable", "candidates"]: + self.assertGreaterEqual(new[0][stat], old[0][stat]) + self.assertEqual(new[1][stat], old[1][stat]) + self.assertEqual(new[2][stat], old[2][stat]) gc.collect(2) - new = gc.get_stats() - self.assertEqual(new[0]["collections"], old[0]["collections"] + 1) + old, new = new, gc.get_stats() + self.assertEqual(new[0]["collections"], old[0]["collections"]) self.assertEqual(new[1]["collections"], old[1]["collections"]) self.assertEqual(new[2]["collections"], old[2]["collections"] + 1) + self.assertEqual(new[0]["duration"], old[0]["duration"]) + self.assertEqual(new[1]["duration"], old[1]["duration"]) + self.assertGreater(new[2]["duration"], old[2]["duration"]) + for stat in ["collected", "uncollectable", "candidates"]: + self.assertEqual(new[0][stat], 
old[0][stat]) + self.assertEqual(new[1][stat], old[1][stat]) + self.assertGreaterEqual(new[2][stat], old[2][stat]) def test_freeze(self): gc.freeze() @@ -1156,6 +1203,37 @@ def test_something(self): """) assert_python_ok("-c", source) + def test_do_not_cleanup_type_subclasses_before_finalization(self): + # See https://github.com/python/cpython/issues/135552 + # If we cleanup weakrefs for tp_subclasses before calling + # the finalizer (__del__) then the line `fail = BaseNode.next.next` + # should fail because we are trying to access a subclass + # attribute. But subclass type cache was not properly invalidated. + code = """ + class BaseNode: + def __del__(self): + BaseNode.next = BaseNode.next.next + fail = BaseNode.next.next + + class Node(BaseNode): + pass + + BaseNode.next = Node() + BaseNode.next.next = Node() + """ + # this test checks garbage collection while interp + # finalization + assert_python_ok("-c", textwrap.dedent(code)) + + code_inside_function = textwrap.dedent(F""" + def test(): + {textwrap.indent(code, ' ')} + + test() + """) + # this test checks regular garbage collection + assert_python_ok("-c", code_inside_function) + @unittest.skipUnless(Py_GIL_DISABLED, "requires free-threaded GC") @unittest.skipIf(_testinternalcapi is None, "requires _testinternalcapi") @@ -1260,9 +1338,11 @@ def test_collect(self): # Check that we got the right info dict for all callbacks for v in self.visit: info = v[2] - self.assertTrue("generation" in info) - self.assertTrue("collected" in info) - self.assertTrue("uncollectable" in info) + self.assertIn("generation", info) + self.assertIn("collected", info) + self.assertIn("uncollectable", info) + self.assertIn("candidates", info) + self.assertIn("duration", info) def test_collect_generation(self): self.preclean() @@ -1450,6 +1530,7 @@ def callback(ignored): self.assertEqual(x, None) @gc_threshold(1000, 0, 0) + @unittest.skipIf(Py_GIL_DISABLED, "requires GC generations or increments") def test_bug1055820d(self): # 
Corresponds to temp2d.py in the bug report. This is very much like # test_bug1055820c, but uses a __del__ method instead of a weakref diff --git a/Lib/test/test_locale.py b/Lib/test/test_locale.py index 71d03f3a3f9..8e49aa8954e 100644 --- a/Lib/test/test_locale.py +++ b/Lib/test/test_locale.py @@ -1,7 +1,7 @@ from decimal import Decimal -from test.support import verbose, is_android, is_emscripten, is_wasi +from test.support import cpython_only, verbose, is_android, linked_to_musl, os_helper from test.support.warnings_helper import check_warnings -from test.support.import_helper import import_fresh_module +from test.support.import_helper import ensure_lazy_imports, import_fresh_module from unittest import mock import unittest import locale @@ -9,6 +9,11 @@ import sys import codecs +class LazyImportTest(unittest.TestCase): + @cpython_only + def test_lazy_import(self): + ensure_lazy_imports("locale", {"re", "warnings"}) + class BaseLocalizedTest(unittest.TestCase): # @@ -351,10 +356,7 @@ def setUp(self): @unittest.skipIf(sys.platform.startswith('aix'), 'bpo-29972: broken test on AIX') - @unittest.skipIf( - is_emscripten or is_wasi, - "musl libc issue on Emscripten/WASI, bpo-46390" - ) + @unittest.skipIf(linked_to_musl(), "musl libc issue, bpo-46390") @unittest.skipIf(sys.platform.startswith("netbsd"), "gh-124108: NetBSD doesn't support UTF-8 for LC_COLLATE") def test_strcoll_with_diacritic(self): @@ -362,10 +364,7 @@ def test_strcoll_with_diacritic(self): @unittest.skipIf(sys.platform.startswith('aix'), 'bpo-29972: broken test on AIX') - @unittest.skipIf( - is_emscripten or is_wasi, - "musl libc issue on Emscripten/WASI, bpo-46390" - ) + @unittest.skipIf(linked_to_musl(), "musl libc issue, bpo-46390") @unittest.skipIf(sys.platform.startswith("netbsd"), "gh-124108: NetBSD doesn't support UTF-8 for LC_COLLATE") def test_strxfrm_with_diacritic(self): @@ -541,7 +540,6 @@ def test_defaults_UTF8(self): # valid. 
Furthermore LC_CTYPE=UTF is used by the UTF-8 locale coercing # during interpreter startup (on macOS). import _locale - import os self.assertEqual(locale._parse_localename('UTF-8'), (None, 'UTF-8')) @@ -551,25 +549,14 @@ def test_defaults_UTF8(self): else: orig_getlocale = None - orig_env = {} try: - for key in ('LC_ALL', 'LC_CTYPE', 'LANG', 'LANGUAGE'): - if key in os.environ: - orig_env[key] = os.environ[key] - del os.environ[key] - - os.environ['LC_CTYPE'] = 'UTF-8' - - with check_warnings(('', DeprecationWarning)): - self.assertEqual(locale.getdefaultlocale(), (None, 'UTF-8')) + with os_helper.EnvironmentVarGuard() as env: + env.unset('LC_ALL', 'LC_CTYPE', 'LANG', 'LANGUAGE') + env.set('LC_CTYPE', 'UTF-8') + with check_warnings(('', DeprecationWarning)): + self.assertEqual(locale.getdefaultlocale(), (None, 'UTF-8')) finally: - for k in orig_env: - os.environ[k] = orig_env[k] - - if 'LC_CTYPE' not in orig_env: - del os.environ['LC_CTYPE'] - if orig_getlocale is not None: _locale._getdefaultlocale = orig_getlocale diff --git a/Lib/test/test_multiprocessing_fork/test_manager.py b/Lib/test/test_multiprocessing_fork/test_manager.py index f8d7eddd652..9efbb83bbb7 100644 --- a/Lib/test/test_multiprocessing_fork/test_manager.py +++ b/Lib/test/test_multiprocessing_fork/test_manager.py @@ -3,22 +3,5 @@ install_tests_in_module_dict(globals(), 'fork', only_type="manager") -import sys # TODO: RUSTPYTHON -class WithManagerTestCondition(WithManagerTestCondition): # TODO: RUSTPYTHON - @unittest.skipIf(sys.platform == 'linux', 'TODO: RUSTPYTHON, times out') - def test_notify_all(self): super().test_notify_all() # TODO: RUSTPYTHON - -class WithManagerTestQueue(WithManagerTestQueue): # TODO: RUSTPYTHON - @unittest.skipIf(sys.platform == 'linux', 'TODO: RUSTPYTHON, times out') - def test_fork(self): super().test_fork() # TODO: RUSTPYTHON - -local_globs = globals().copy() # TODO: RUSTPYTHON -for name, base in local_globs.items(): # TODO: RUSTPYTHON - if 
name.startswith('WithManagerTest') and issubclass(base, unittest.TestCase): # TODO: RUSTPYTHON - base = unittest.skipIf( # TODO: RUSTPYTHON - sys.platform == 'linux', # TODO: RUSTPYTHON - 'TODO: RUSTPYTHON flaky BrokenPipeError, flaky ConnectionRefusedError, flaky ConnectionResetError, flaky EOFError' - )(base) # TODO: RUSTPYTHON - if __name__ == '__main__': unittest.main() diff --git a/Lib/test/test_multiprocessing_fork/test_misc.py b/Lib/test/test_multiprocessing_fork/test_misc.py index bcf0858258e..891a494020c 100644 --- a/Lib/test/test_multiprocessing_fork/test_misc.py +++ b/Lib/test/test_multiprocessing_fork/test_misc.py @@ -3,24 +3,5 @@ install_tests_in_module_dict(globals(), 'fork', exclude_types=True) -import sys # TODO: RUSTPYTHON -class TestManagerExceptions(TestManagerExceptions): # TODO: RUSTPYTHON - @unittest.skipIf(sys.platform == 'linux', "TODO: RUSTPYTHON flaky") - def test_queue_get(self): super().test_queue_get() # TODO: RUSTPYTHON - -@unittest.skipIf(sys.platform == 'linux', "TODO: RUSTPYTHON flaky") -class TestInitializers(TestInitializers): pass # TODO: RUSTPYTHON - -class TestStartMethod(TestStartMethod): # TODO: RUSTPYTHON - @unittest.skipIf(sys.platform == 'linux', "TODO: RUSTPYTHON flaky") - def test_nested_startmethod(self): super().test_nested_startmethod() # TODO: RUSTPYTHON - -@unittest.skipIf(sys.platform == 'linux', "TODO: RUSTPYTHON flaky") -class TestSyncManagerTypes(TestSyncManagerTypes): pass # TODO: RUSTPYTHON - -class MiscTestCase(MiscTestCase): # TODO: RUSTPYTHON - @unittest.skipIf(sys.platform == 'linux', "TODO: RUSTPYTHON flaky") - def test_forked_thread_not_started(self): super().test_forked_thread_not_started() # TODO: RUSTPYTHON - if __name__ == '__main__': unittest.main() diff --git a/Lib/test/test_multiprocessing_fork/test_threads.py b/Lib/test/test_multiprocessing_fork/test_threads.py index 1065ebf7fe4..1670e34cb17 100644 --- a/Lib/test/test_multiprocessing_fork/test_threads.py +++ 
b/Lib/test/test_multiprocessing_fork/test_threads.py @@ -3,14 +3,5 @@ install_tests_in_module_dict(globals(), 'fork', only_type="threads") -import os, sys # TODO: RUSTPYTHON -class WithThreadsTestPool(WithThreadsTestPool): # TODO: RUSTPYTHON - @unittest.skip("TODO: RUSTPYTHON; flaky environment pollution when running rustpython -m test --fail-env-changed due to unknown reason") - def test_terminate(self): super().test_terminate() # TODO: RUSTPYTHON - -class WithThreadsTestManagerRestart(WithThreadsTestManagerRestart): # TODO: RUSTPYTHON - @unittest.skipIf(sys.platform == 'linux', 'TODO: RUSTPYTHON flaky flaky BrokenPipeError, flaky ConnectionRefusedError, flaky ConnectionResetError, flaky EOFError') - def test_rapid_restart(self): super().test_rapid_restart() # TODO: RUSTPYTHON - if __name__ == '__main__': unittest.main() diff --git a/Lib/test/test_os.py b/Lib/test/test_os.py index d63dc60be31..00bd75bab51 100644 --- a/Lib/test/test_os.py +++ b/Lib/test/test_os.py @@ -5574,7 +5574,6 @@ def test_fork_warns_when_non_python_thread_exists(self): self.assertEqual(err.decode("utf-8"), "") self.assertEqual(out.decode("utf-8"), "") - @unittest.expectedFailure # TODO: RUSTPYTHON; AssertionError: b"can't fork at interpreter shutdown" not found in b"Exception ignored in: \nAttributeError: 'NoneType' object has no attribute 'fork'\n" def test_fork_at_finalization(self): code = """if 1: import atexit diff --git a/Lib/test/test_poplib.py b/Lib/test/test_poplib.py new file mode 100644 index 00000000000..ef2da97f867 --- /dev/null +++ b/Lib/test/test_poplib.py @@ -0,0 +1,571 @@ +"""Test script for poplib module.""" + +# Modified by Giampaolo Rodola' to give poplib.POP3 and poplib.POP3_SSL +# a real test suite + +import poplib +import socket +import os +import errno +import threading + +import unittest +from unittest import TestCase, skipUnless +from test import support as test_support +from test.support import hashlib_helper +from test.support import socket_helper +from 
test.support import threading_helper +from test.support import asynchat +from test.support import asyncore + + +test_support.requires_working_socket(module=True) + +HOST = socket_helper.HOST +PORT = 0 + +SUPPORTS_SSL = False +if hasattr(poplib, 'POP3_SSL'): + import ssl + + SUPPORTS_SSL = True + CERTFILE = os.path.join(os.path.dirname(__file__) or os.curdir, "certdata", "keycert3.pem") + CAFILE = os.path.join(os.path.dirname(__file__) or os.curdir, "certdata", "pycacert.pem") + +requires_ssl = skipUnless(SUPPORTS_SSL, 'SSL not supported') + +# the dummy data returned by server when LIST and RETR commands are issued +LIST_RESP = b'1 1\r\n2 2\r\n3 3\r\n4 4\r\n5 5\r\n.\r\n' +RETR_RESP = b"""From: postmaster@python.org\ +\r\nContent-Type: text/plain\r\n\ +MIME-Version: 1.0\r\n\ +Subject: Dummy\r\n\ +\r\n\ +line1\r\n\ +line2\r\n\ +line3\r\n\ +.\r\n""" + + +class DummyPOP3Handler(asynchat.async_chat): + + CAPAS = {'UIDL': [], 'IMPLEMENTATION': ['python-testlib-pop-server']} + enable_UTF8 = False + + def __init__(self, conn): + asynchat.async_chat.__init__(self, conn) + self.set_terminator(b"\r\n") + self.in_buffer = [] + self.push('+OK dummy pop3 server ready. ') + self.tls_active = False + self.tls_starting = False + + def collect_incoming_data(self, data): + self.in_buffer.append(data) + + def found_terminator(self): + line = b''.join(self.in_buffer) + line = str(line, 'ISO-8859-1') + self.in_buffer = [] + cmd = line.split(' ')[0].lower() + space = line.find(' ') + if space != -1: + arg = line[space + 1:] + else: + arg = "" + if hasattr(self, 'cmd_' + cmd): + method = getattr(self, 'cmd_' + cmd) + method(arg) + else: + self.push('-ERR unrecognized POP3 command "%s".' 
%cmd) + + def handle_error(self): + raise + + def push(self, data): + asynchat.async_chat.push(self, data.encode("ISO-8859-1") + b'\r\n') + + def cmd_echo(self, arg): + # sends back the received string (used by the test suite) + self.push(arg) + + def cmd_user(self, arg): + if arg != "guido": + self.push("-ERR no such user") + self.push('+OK password required') + + def cmd_pass(self, arg): + if arg != "python": + self.push("-ERR wrong password") + self.push('+OK 10 messages') + + def cmd_stat(self, arg): + self.push('+OK 10 100') + + def cmd_list(self, arg): + if arg: + self.push('+OK %s %s' % (arg, arg)) + else: + self.push('+OK') + asynchat.async_chat.push(self, LIST_RESP) + + cmd_uidl = cmd_list + + def cmd_retr(self, arg): + self.push('+OK %s bytes' %len(RETR_RESP)) + asynchat.async_chat.push(self, RETR_RESP) + + cmd_top = cmd_retr + + def cmd_dele(self, arg): + self.push('+OK message marked for deletion.') + + def cmd_noop(self, arg): + self.push('+OK done nothing.') + + def cmd_rpop(self, arg): + self.push('+OK done nothing.') + + def cmd_apop(self, arg): + self.push('+OK done nothing.') + + def cmd_quit(self, arg): + self.push('+OK closing.') + self.close_when_done() + + def _get_capas(self): + _capas = dict(self.CAPAS) + if not self.tls_active and SUPPORTS_SSL: + _capas['STLS'] = [] + return _capas + + def cmd_capa(self, arg): + self.push('+OK Capability list follows') + if self._get_capas(): + for cap, params in self._get_capas().items(): + _ln = [cap] + if params: + _ln.extend(params) + self.push(' '.join(_ln)) + self.push('.') + + def cmd_utf8(self, arg): + self.push('+OK I know RFC6856' + if self.enable_UTF8 + else '-ERR What is UTF8?!') + + if SUPPORTS_SSL: + + def cmd_stls(self, arg): + if self.tls_active is False: + self.push('+OK Begin TLS negotiation') + context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) + context.load_cert_chain(CERTFILE) + tls_sock = context.wrap_socket(self.socket, + server_side=True, + do_handshake_on_connect=False, + 
suppress_ragged_eofs=False) + self.del_channel() + self.set_socket(tls_sock) + self.tls_active = True + self.tls_starting = True + self.in_buffer = [] + self._do_tls_handshake() + else: + self.push('-ERR Command not permitted when TLS active') + + def _do_tls_handshake(self): + try: + self.socket.do_handshake() + except ssl.SSLError as err: + if err.args[0] in (ssl.SSL_ERROR_WANT_READ, + ssl.SSL_ERROR_WANT_WRITE): + return + elif err.args[0] == ssl.SSL_ERROR_EOF: + return self.handle_close() + # TODO: SSLError does not expose alert information + elif ("SSLV3_ALERT_BAD_CERTIFICATE" in err.args[1] or + "SSLV3_ALERT_CERTIFICATE_UNKNOWN" in err.args[1]): + return self.handle_close() + raise + except OSError as err: + if err.args[0] == errno.ECONNABORTED: + return self.handle_close() + else: + self.tls_active = True + self.tls_starting = False + + def handle_read(self): + if self.tls_starting: + self._do_tls_handshake() + else: + try: + asynchat.async_chat.handle_read(self) + except ssl.SSLEOFError: + self.handle_close() + +class DummyPOP3Server(asyncore.dispatcher, threading.Thread): + + handler = DummyPOP3Handler + + def __init__(self, address, af=socket.AF_INET): + threading.Thread.__init__(self) + asyncore.dispatcher.__init__(self) + self.daemon = True + self.create_socket(af, socket.SOCK_STREAM) + self.bind(address) + self.listen(5) + self.active = False + self.active_lock = threading.Lock() + self.host, self.port = self.socket.getsockname()[:2] + self.handler_instance = None + + def start(self): + assert not self.active + self.__flag = threading.Event() + threading.Thread.start(self) + self.__flag.wait() + + def run(self): + self.active = True + self.__flag.set() + try: + while self.active and asyncore.socket_map: + with self.active_lock: + asyncore.loop(timeout=0.1, count=1) + finally: + asyncore.close_all(ignore_all=True) + + def stop(self): + assert self.active + self.active = False + self.join() + + def handle_accepted(self, conn, addr): + 
self.handler_instance = self.handler(conn) + + def handle_connect(self): + self.close() + handle_read = handle_connect + + def writable(self): + return 0 + + def handle_error(self): + raise + + +class TestPOP3Class(TestCase): + def assertOK(self, resp): + self.assertStartsWith(resp, b"+OK") + + def setUp(self): + self.server = DummyPOP3Server((HOST, PORT)) + self.server.start() + self.client = poplib.POP3(self.server.host, self.server.port, + timeout=test_support.LOOPBACK_TIMEOUT) + + def tearDown(self): + self.client.close() + self.server.stop() + # Explicitly clear the attribute to prevent dangling thread + self.server = None + + def test_getwelcome(self): + self.assertEqual(self.client.getwelcome(), + b'+OK dummy pop3 server ready. ') + + def test_exceptions(self): + self.assertRaises(poplib.error_proto, self.client._shortcmd, 'echo -err') + + def test_user(self): + self.assertOK(self.client.user('guido')) + self.assertRaises(poplib.error_proto, self.client.user, 'invalid') + + def test_pass_(self): + self.assertOK(self.client.pass_('python')) + self.assertRaises(poplib.error_proto, self.client.user, 'invalid') + + def test_stat(self): + self.assertEqual(self.client.stat(), (10, 100)) + + original_shortcmd = self.client._shortcmd + def mock_shortcmd_invalid_format(cmd): + if cmd == 'STAT': + return b'+OK' + return original_shortcmd(cmd) + + self.client._shortcmd = mock_shortcmd_invalid_format + with self.assertRaises(poplib.error_proto): + self.client.stat() + + def mock_shortcmd_invalid_data(cmd): + if cmd == 'STAT': + return b'+OK abc def' + return original_shortcmd(cmd) + + self.client._shortcmd = mock_shortcmd_invalid_data + with self.assertRaises(poplib.error_proto): + self.client.stat() + + def mock_shortcmd_extra_fields(cmd): + if cmd == 'STAT': + return b'+OK 1 2 3 4 5' + return original_shortcmd(cmd) + + self.client._shortcmd = mock_shortcmd_extra_fields + + result = self.client.stat() + self.assertEqual(result, (1, 2)) + + self.client._shortcmd = 
original_shortcmd + + def test_list(self): + self.assertEqual(self.client.list()[1:], + ([b'1 1', b'2 2', b'3 3', b'4 4', b'5 5'], + 25)) + self.assertEndsWith(self.client.list('1'), b"OK 1 1") + + def test_retr(self): + expected = (b'+OK 116 bytes', + [b'From: postmaster@python.org', b'Content-Type: text/plain', + b'MIME-Version: 1.0', b'Subject: Dummy', + b'', b'line1', b'line2', b'line3'], + 113) + foo = self.client.retr('foo') + self.assertEqual(foo, expected) + + def test_too_long_lines(self): + self.assertRaises(poplib.error_proto, self.client._shortcmd, + 'echo +%s' % ((poplib._MAXLINE + 10) * 'a')) + + def test_dele(self): + self.assertOK(self.client.dele('foo')) + + def test_noop(self): + self.assertOK(self.client.noop()) + + def test_rpop(self): + self.assertOK(self.client.rpop('foo')) + + @hashlib_helper.requires_hashdigest('md5', openssl=True) + def test_apop_normal(self): + self.assertOK(self.client.apop('foo', 'dummypassword')) + + @hashlib_helper.requires_hashdigest('md5', openssl=True) + def test_apop_REDOS(self): + # Replace welcome with very long evil welcome. + # NB The upper bound on welcome length is currently 2048. + # At this length, evil input makes each apop call take + # on the order of milliseconds instead of microseconds. + evil_welcome = b'+OK' + (b'<' * 1000000) + with test_support.swap_attr(self.client, 'welcome', evil_welcome): + # The evil welcome is invalid, so apop should throw. 
+ self.assertRaises(poplib.error_proto, self.client.apop, 'a', 'kb') + + def test_top(self): + expected = (b'+OK 116 bytes', + [b'From: postmaster@python.org', b'Content-Type: text/plain', + b'MIME-Version: 1.0', b'Subject: Dummy', b'', + b'line1', b'line2', b'line3'], + 113) + self.assertEqual(self.client.top(1, 1), expected) + + def test_uidl(self): + self.client.uidl() + self.client.uidl('foo') + + def test_utf8_raises_if_unsupported(self): + self.server.handler.enable_UTF8 = False + self.assertRaises(poplib.error_proto, self.client.utf8) + + def test_utf8(self): + self.server.handler.enable_UTF8 = True + expected = b'+OK I know RFC6856' + result = self.client.utf8() + self.assertEqual(result, expected) + + def test_capa(self): + capa = self.client.capa() + self.assertTrue('IMPLEMENTATION' in capa.keys()) + + def test_quit(self): + resp = self.client.quit() + self.assertTrue(resp) + self.assertIsNone(self.client.sock) + self.assertIsNone(self.client.file) + + @requires_ssl + def test_stls_capa(self): + capa = self.client.capa() + self.assertTrue('STLS' in capa.keys()) + + @requires_ssl + def test_stls(self): + expected = b'+OK Begin TLS negotiation' + resp = self.client.stls() + self.assertEqual(resp, expected) + + @requires_ssl + def test_stls_context(self): + expected = b'+OK Begin TLS negotiation' + ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + ctx.load_verify_locations(CAFILE) + self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED) + self.assertEqual(ctx.check_hostname, True) + with self.assertRaises(ssl.CertificateError): + resp = self.client.stls(context=ctx) + self.client = poplib.POP3("localhost", self.server.port, + timeout=test_support.LOOPBACK_TIMEOUT) + resp = self.client.stls(context=ctx) + self.assertEqual(resp, expected) + + +if SUPPORTS_SSL: + from test.test_ftplib import SSLConnection + + class DummyPOP3_SSLHandler(SSLConnection, DummyPOP3Handler): + + def __init__(self, conn): + asynchat.async_chat.__init__(self, conn) + 
self.secure_connection() + self.set_terminator(b"\r\n") + self.in_buffer = [] + self.push('+OK dummy pop3 server ready. ') + self.tls_active = True + self.tls_starting = False + + +@requires_ssl +class TestPOP3_SSLClass(TestPOP3Class): + # repeat previous tests by using poplib.POP3_SSL + + def setUp(self): + self.server = DummyPOP3Server((HOST, PORT)) + self.server.handler = DummyPOP3_SSLHandler + self.server.start() + self.client = poplib.POP3_SSL(self.server.host, self.server.port) + + def test__all__(self): + self.assertIn('POP3_SSL', poplib.__all__) + + def test_context(self): + ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + ctx.check_hostname = False + ctx.verify_mode = ssl.CERT_NONE + + self.client.quit() + self.client = poplib.POP3_SSL(self.server.host, self.server.port, + context=ctx) + self.assertIsInstance(self.client.sock, ssl.SSLSocket) + self.assertIs(self.client.sock.context, ctx) + self.assertStartsWith(self.client.noop(), b'+OK') + + def test_stls(self): + self.assertRaises(poplib.error_proto, self.client.stls) + + test_stls_context = test_stls + + def test_stls_capa(self): + capa = self.client.capa() + self.assertFalse('STLS' in capa.keys()) + + +@requires_ssl +class TestPOP3_TLSClass(TestPOP3Class): + # repeat previous tests by using poplib.POP3.stls() + + def setUp(self): + self.server = DummyPOP3Server((HOST, PORT)) + self.server.start() + self.client = poplib.POP3(self.server.host, self.server.port, + timeout=test_support.LOOPBACK_TIMEOUT) + self.client.stls() + + def tearDown(self): + if self.client.file is not None and self.client.sock is not None: + try: + self.client.quit() + except poplib.error_proto: + # happens in the test_too_long_lines case; the overlong + # response will be treated as response to QUIT and raise + # this exception + self.client.close() + self.server.stop() + # Explicitly clear the attribute to prevent dangling thread + self.server = None + + def test_stls(self): + self.assertRaises(poplib.error_proto, 
self.client.stls) + + test_stls_context = test_stls + + def test_stls_capa(self): + capa = self.client.capa() + self.assertFalse(b'STLS' in capa.keys()) + + +class TestTimeouts(TestCase): + + def setUp(self): + self.evt = threading.Event() + self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self.sock.settimeout(60) # Safety net. Look issue 11812 + self.port = socket_helper.bind_port(self.sock) + self.thread = threading.Thread(target=self.server, args=(self.evt, self.sock)) + self.thread.daemon = True + self.thread.start() + self.evt.wait() + + def tearDown(self): + self.thread.join() + # Explicitly clear the attribute to prevent dangling thread + self.thread = None + + def server(self, evt, serv): + serv.listen() + evt.set() + try: + conn, addr = serv.accept() + conn.send(b"+ Hola mundo\n") + conn.close() + except TimeoutError: + pass + finally: + serv.close() + + def testTimeoutDefault(self): + self.assertIsNone(socket.getdefaulttimeout()) + socket.setdefaulttimeout(test_support.LOOPBACK_TIMEOUT) + try: + pop = poplib.POP3(HOST, self.port) + finally: + socket.setdefaulttimeout(None) + self.assertEqual(pop.sock.gettimeout(), test_support.LOOPBACK_TIMEOUT) + pop.close() + + def testTimeoutNone(self): + self.assertIsNone(socket.getdefaulttimeout()) + socket.setdefaulttimeout(30) + try: + pop = poplib.POP3(HOST, self.port, timeout=None) + finally: + socket.setdefaulttimeout(None) + self.assertIsNone(pop.sock.gettimeout()) + pop.close() + + def testTimeoutValue(self): + pop = poplib.POP3(HOST, self.port, timeout=test_support.LOOPBACK_TIMEOUT) + self.assertEqual(pop.sock.gettimeout(), test_support.LOOPBACK_TIMEOUT) + pop.close() + with self.assertRaises(ValueError): + poplib.POP3(HOST, self.port, timeout=0) + + +def setUpModule(): + thread_info = threading_helper.threading_setup() + unittest.addModuleCleanup(threading_helper.threading_cleanup, *thread_info) + + +if __name__ == '__main__': + unittest.main() diff --git a/Lib/test/test_symtable.py 
b/Lib/test/test_symtable.py index ae93ee8d91f..1653ab4a718 100644 --- a/Lib/test/test_symtable.py +++ b/Lib/test/test_symtable.py @@ -561,7 +561,6 @@ def get_identifiers_recursive(self, st, res): for ch in st.get_children(): self.get_identifiers_recursive(ch, res) - @unittest.expectedFailure # TODO: RUSTPYTHON; AssertionError: 2 != 1 def test_loopvar_in_only_one_scope(self): # ensure that the loop variable appears only once in the symtable comps = [ diff --git a/Lib/test/test_tabnanny.py b/Lib/test/test_tabnanny.py index 372be9eb8c3..d7a77eb26e4 100644 --- a/Lib/test/test_tabnanny.py +++ b/Lib/test/test_tabnanny.py @@ -316,7 +316,6 @@ def validate_cmd(self, *args, stdout="", stderr="", partial=False, expect_failur self.assertListEqual(out.splitlines(), stdout.splitlines()) self.assertListEqual(err.splitlines(), stderr.splitlines()) - @unittest.expectedFailure # TODO: RUSTPYTHON; Should displays error when errored python file is given. def test_with_errored_file(self): """Should displays error when errored python file is given.""" with TemporaryPyFile(SOURCE_CODES["wrong_indented"]) as file_path: diff --git a/Lib/test/test_threading.py b/Lib/test/test_threading.py index 8db0bbdb949..17693ae093f 100644 --- a/Lib/test/test_threading.py +++ b/Lib/test/test_threading.py @@ -1162,8 +1162,6 @@ def import_threading(): self.assertEqual(out, b'') self.assertEqual(err, b'') - # TODO: RUSTPYTHON - __del__ not called during interpreter finalization (no cyclic GC) - @unittest.expectedFailure def test_start_new_thread_at_finalization(self): code = """if 1: import _thread diff --git a/Lib/test/test_tokenize.py b/Lib/test/test_tokenize.py index 44ef4e24165..394a87c3601 100644 --- a/Lib/test/test_tokenize.py +++ b/Lib/test/test_tokenize.py @@ -1,17 +1,22 @@ -from test import support -from test.support import os_helper -from tokenize import (tokenize, _tokenize, untokenize, NUMBER, NAME, OP, - STRING, ENDMARKER, ENCODING, tok_name, detect_encoding, - open as tokenize_open, 
Untokenizer, generate_tokens, - NEWLINE) -from io import BytesIO, StringIO +import contextlib +import itertools +import os +import re +import string +import tempfile +import token +import tokenize import unittest +from io import BytesIO, StringIO from textwrap import dedent from unittest import TestCase, mock -from test.test_grammar import (VALID_UNDERSCORE_LITERALS, - INVALID_UNDERSCORE_LITERALS) -import os -import token +from test import support +from test.support import os_helper +from test.support.script_helper import run_test_script, make_script, run_python_until_end +from test.support.numbers import ( + VALID_UNDERSCORE_LITERALS, + INVALID_UNDERSCORE_LITERALS, +) # Converts a source string into a list of textual representation @@ -24,12 +29,12 @@ def stringify_tokens_from_source(token_generator, source_string): missing_trailing_nl = source_string[-1] not in '\r\n' for type, token, start, end, line in token_generator: - if type == ENDMARKER: + if type == tokenize.ENDMARKER: break # Ignore the new line on the last line if the input lacks one - if missing_trailing_nl and type == NEWLINE and end[0] == num_lines: + if missing_trailing_nl and type == tokenize.NEWLINE and end[0] == num_lines: continue - type = tok_name[type] + type = tokenize.tok_name[type] result.append(f" {type:10} {token!r:13} {start} {end}") return result @@ -45,18 +50,37 @@ def check_tokenize(self, s, expected): # Format the tokens in s in a table format. # The ENDMARKER and final NEWLINE are omitted. 
f = BytesIO(s.encode('utf-8')) - result = stringify_tokens_from_source(tokenize(f.readline), s) + result = stringify_tokens_from_source(tokenize.tokenize(f.readline), s) self.assertEqual(result, [" ENCODING 'utf-8' (0, 0) (0, 0)"] + expected.rstrip().splitlines()) + def test_invalid_readline(self): + def gen(): + yield "sdfosdg" + yield "sdfosdg" + with self.assertRaises(TypeError): + list(tokenize.tokenize(gen().__next__)) + + def gen(): + yield b"sdfosdg" + yield b"sdfosdg" + with self.assertRaises(TypeError): + list(tokenize.generate_tokens(gen().__next__)) + + def gen(): + yield "sdfosdg" + 1/0 + with self.assertRaises(ZeroDivisionError): + list(tokenize.generate_tokens(gen().__next__)) + def test_implicit_newline(self): # Make sure that the tokenizer puts in an implicit NEWLINE # when the input lacks a trailing new line. f = BytesIO("x".encode('utf-8')) - tokens = list(tokenize(f.readline)) - self.assertEqual(tokens[-2].type, NEWLINE) - self.assertEqual(tokens[-1].type, ENDMARKER) + tokens = list(tokenize.tokenize(f.readline)) + self.assertEqual(tokens[-2].type, tokenize.NEWLINE) + self.assertEqual(tokens[-1].type, tokenize.ENDMARKER) def test_basic(self): self.check_tokenize("1 + 1", """\ @@ -83,6 +107,32 @@ def test_basic(self): NEWLINE '\\n' (4, 26) (4, 27) DEDENT '' (5, 0) (5, 0) """) + + self.check_tokenize("if True:\r\n # NL\r\n foo='bar'\r\n\r\n", """\ + NAME 'if' (1, 0) (1, 2) + NAME 'True' (1, 3) (1, 7) + OP ':' (1, 7) (1, 8) + NEWLINE '\\r\\n' (1, 8) (1, 10) + COMMENT '# NL' (2, 4) (2, 8) + NL '\\r\\n' (2, 8) (2, 10) + INDENT ' ' (3, 0) (3, 4) + NAME 'foo' (3, 4) (3, 7) + OP '=' (3, 7) (3, 8) + STRING "\'bar\'" (3, 8) (3, 13) + NEWLINE '\\r\\n' (3, 13) (3, 15) + NL '\\r\\n' (4, 0) (4, 2) + DEDENT '' (5, 0) (5, 0) + """) + + self.check_tokenize("x = 1 + \\\r\n1\r\n", """\ + NAME 'x' (1, 0) (1, 1) + OP '=' (1, 2) (1, 3) + NUMBER '1' (1, 4) (1, 5) + OP '+' (1, 6) (1, 7) + NUMBER '1' (2, 0) (2, 1) + NEWLINE '\\r\\n' (2, 1) (2, 3) + """) + 
indent_error_file = b"""\ def k(x): x += 2 @@ -91,9 +141,18 @@ def k(x): readline = BytesIO(indent_error_file).readline with self.assertRaisesRegex(IndentationError, "unindent does not match any " - "outer indentation level"): - for tok in tokenize(readline): + "outer indentation level") as e: + for tok in tokenize.tokenize(readline): pass + self.assertEqual(e.exception.lineno, 3) + self.assertEqual(e.exception.filename, '') + self.assertEqual(e.exception.end_lineno, None) + self.assertEqual(e.exception.end_offset, None) + self.assertEqual( + e.exception.msg, + 'unindent does not match any outer indentation level') + self.assertEqual(e.exception.offset, 9) + self.assertEqual(e.exception.text, ' x += 5') def test_int(self): # Ordinary integers and binary operators @@ -177,7 +236,7 @@ def test_long(self): """) def test_float(self): - # Floating point numbers + # Floating-point numbers self.check_tokenize("x = 3.14159", """\ NAME 'x' (1, 0) (1, 1) OP '=' (1, 2) (1, 3) @@ -219,8 +278,8 @@ def test_float(self): def test_underscore_literals(self): def number_token(s): f = BytesIO(s.encode('utf-8')) - for toktype, token, start, end, line in tokenize(f.readline): - if toktype == NUMBER: + for toktype, token, start, end, line in tokenize.tokenize(f.readline): + if toktype == tokenize.NUMBER: return token return 'invalid token' for lit in VALID_UNDERSCORE_LITERALS: @@ -228,7 +287,16 @@ def number_token(s): # this won't work with compound complex inputs continue self.assertEqual(number_token(lit), lit) + # Valid cases with extra underscores in the tokenize module + # See gh-105549 for context + extra_valid_cases = {"0_7", "09_99"} for lit in INVALID_UNDERSCORE_LITERALS: + if lit in extra_valid_cases: + continue + try: + number_token(lit) + except tokenize.TokenError: + continue self.assertNotEqual(number_token(lit), lit) def test_string(self): @@ -380,21 +448,175 @@ def test_string(self): STRING 'rb"\""a\\\\\\nb\\\\\\nc"\""' (1, 0) (3, 4) """) self.check_tokenize('f"abc"', 
"""\ - STRING 'f"abc"' (1, 0) (1, 6) + FSTRING_START 'f"' (1, 0) (1, 2) + FSTRING_MIDDLE 'abc' (1, 2) (1, 5) + FSTRING_END '"' (1, 5) (1, 6) """) self.check_tokenize('fR"a{b}c"', """\ - STRING 'fR"a{b}c"' (1, 0) (1, 9) + FSTRING_START 'fR"' (1, 0) (1, 3) + FSTRING_MIDDLE 'a' (1, 3) (1, 4) + OP '{' (1, 4) (1, 5) + NAME 'b' (1, 5) (1, 6) + OP '}' (1, 6) (1, 7) + FSTRING_MIDDLE 'c' (1, 7) (1, 8) + FSTRING_END '"' (1, 8) (1, 9) + """) + self.check_tokenize('fR"a{{{b!r}}}c"', """\ + FSTRING_START 'fR"' (1, 0) (1, 3) + FSTRING_MIDDLE 'a{' (1, 3) (1, 5) + OP '{' (1, 6) (1, 7) + NAME 'b' (1, 7) (1, 8) + OP '!' (1, 8) (1, 9) + NAME 'r' (1, 9) (1, 10) + OP '}' (1, 10) (1, 11) + FSTRING_MIDDLE '}' (1, 11) (1, 12) + FSTRING_MIDDLE 'c' (1, 13) (1, 14) + FSTRING_END '"' (1, 14) (1, 15) + """) + self.check_tokenize('f"{{{1+1}}}"', """\ + FSTRING_START 'f"' (1, 0) (1, 2) + FSTRING_MIDDLE '{' (1, 2) (1, 3) + OP '{' (1, 4) (1, 5) + NUMBER '1' (1, 5) (1, 6) + OP '+' (1, 6) (1, 7) + NUMBER '1' (1, 7) (1, 8) + OP '}' (1, 8) (1, 9) + FSTRING_MIDDLE '}' (1, 9) (1, 10) + FSTRING_END '"' (1, 11) (1, 12) + """) + self.check_tokenize('f"""{f\'\'\'{f\'{f"{1+1}"}\'}\'\'\'}"""', """\ + FSTRING_START 'f\"""' (1, 0) (1, 4) + OP '{' (1, 4) (1, 5) + FSTRING_START "f'''" (1, 5) (1, 9) + OP '{' (1, 9) (1, 10) + FSTRING_START "f'" (1, 10) (1, 12) + OP '{' (1, 12) (1, 13) + FSTRING_START 'f"' (1, 13) (1, 15) + OP '{' (1, 15) (1, 16) + NUMBER '1' (1, 16) (1, 17) + OP '+' (1, 17) (1, 18) + NUMBER '1' (1, 18) (1, 19) + OP '}' (1, 19) (1, 20) + FSTRING_END '"' (1, 20) (1, 21) + OP '}' (1, 21) (1, 22) + FSTRING_END "'" (1, 22) (1, 23) + OP '}' (1, 23) (1, 24) + FSTRING_END "'''" (1, 24) (1, 27) + OP '}' (1, 27) (1, 28) + FSTRING_END '\"""' (1, 28) (1, 31) + """) + self.check_tokenize('f""" x\nstr(data, encoding={invalid!r})\n"""', """\ + FSTRING_START 'f\"""' (1, 0) (1, 4) + FSTRING_MIDDLE ' x\\nstr(data, encoding=' (1, 4) (2, 19) + OP '{' (2, 19) (2, 20) + NAME 'invalid' (2, 20) (2, 27) + OP '!' 
(2, 27) (2, 28) + NAME 'r' (2, 28) (2, 29) + OP '}' (2, 29) (2, 30) + FSTRING_MIDDLE ')\\n' (2, 30) (3, 0) + FSTRING_END '\"""' (3, 0) (3, 3) + """) + self.check_tokenize('f"""123456789\nsomething{None}bad"""', """\ + FSTRING_START 'f\"""' (1, 0) (1, 4) + FSTRING_MIDDLE '123456789\\nsomething' (1, 4) (2, 9) + OP '{' (2, 9) (2, 10) + NAME 'None' (2, 10) (2, 14) + OP '}' (2, 14) (2, 15) + FSTRING_MIDDLE 'bad' (2, 15) (2, 18) + FSTRING_END '\"""' (2, 18) (2, 21) """) self.check_tokenize('f"""abc"""', """\ - STRING 'f\"\"\"abc\"\"\"' (1, 0) (1, 10) + FSTRING_START 'f\"""' (1, 0) (1, 4) + FSTRING_MIDDLE 'abc' (1, 4) (1, 7) + FSTRING_END '\"""' (1, 7) (1, 10) """) self.check_tokenize(r'f"abc\ def"', """\ - STRING 'f"abc\\\\\\ndef"' (1, 0) (2, 4) + FSTRING_START 'f"' (1, 0) (1, 2) + FSTRING_MIDDLE 'abc\\\\\\ndef' (1, 2) (2, 3) + FSTRING_END '"' (2, 3) (2, 4) """) self.check_tokenize(r'Rf"abc\ def"', """\ - STRING 'Rf"abc\\\\\\ndef"' (1, 0) (2, 4) + FSTRING_START 'Rf"' (1, 0) (1, 3) + FSTRING_MIDDLE 'abc\\\\\\ndef' (1, 3) (2, 3) + FSTRING_END '"' (2, 3) (2, 4) + """) + self.check_tokenize("f'some words {a+b:.3f} more words {c+d=} final words'", """\ + FSTRING_START "f'" (1, 0) (1, 2) + FSTRING_MIDDLE 'some words ' (1, 2) (1, 13) + OP '{' (1, 13) (1, 14) + NAME 'a' (1, 14) (1, 15) + OP '+' (1, 15) (1, 16) + NAME 'b' (1, 16) (1, 17) + OP ':' (1, 17) (1, 18) + FSTRING_MIDDLE '.3f' (1, 18) (1, 21) + OP '}' (1, 21) (1, 22) + FSTRING_MIDDLE ' more words ' (1, 22) (1, 34) + OP '{' (1, 34) (1, 35) + NAME 'c' (1, 35) (1, 36) + OP '+' (1, 36) (1, 37) + NAME 'd' (1, 37) (1, 38) + OP '=' (1, 38) (1, 39) + OP '}' (1, 39) (1, 40) + FSTRING_MIDDLE ' final words' (1, 40) (1, 52) + FSTRING_END "'" (1, 52) (1, 53) + """) + self.check_tokenize("""\ +f'''{ +3 +=}'''""", """\ + FSTRING_START "f'''" (1, 0) (1, 4) + OP '{' (1, 4) (1, 5) + NL '\\n' (1, 5) (1, 6) + NUMBER '3' (2, 0) (2, 1) + NL '\\n' (2, 1) (2, 2) + OP '=' (3, 0) (3, 1) + OP '}' (3, 1) (3, 2) + FSTRING_END "'''" (3, 2) (3, 5) + 
""") + self.check_tokenize("""\ +f'''__{ + x:a +}__'''""", """\ + FSTRING_START "f'''" (1, 0) (1, 4) + FSTRING_MIDDLE '__' (1, 4) (1, 6) + OP '{' (1, 6) (1, 7) + NL '\\n' (1, 7) (1, 8) + NAME 'x' (2, 4) (2, 5) + OP ':' (2, 5) (2, 6) + FSTRING_MIDDLE 'a\\n' (2, 6) (3, 0) + OP '}' (3, 0) (3, 1) + FSTRING_MIDDLE '__' (3, 1) (3, 3) + FSTRING_END "'''" (3, 3) (3, 6) + """) + self.check_tokenize("""\ +f'''__{ + x:a + b + c + d +}__'''""", """\ + FSTRING_START "f'''" (1, 0) (1, 4) + FSTRING_MIDDLE '__' (1, 4) (1, 6) + OP '{' (1, 6) (1, 7) + NL '\\n' (1, 7) (1, 8) + NAME 'x' (2, 4) (2, 5) + OP ':' (2, 5) (2, 6) + FSTRING_MIDDLE 'a\\n b\\n c\\n d\\n' (2, 6) (6, 0) + OP '}' (6, 0) (6, 1) + FSTRING_MIDDLE '__' (6, 1) (6, 3) + FSTRING_END "'''" (6, 3) (6, 6) + """) + + self.check_tokenize("""\ + '''Autorzy, którzy tą jednostkę mają wpisani jako AKTUALNA -- czyli + aktualni pracownicy, obecni pracownicy''' +""", """\ + INDENT ' ' (1, 0) (1, 4) + STRING "'''Autorzy, którzy tą jednostkę mają wpisani jako AKTUALNA -- czyli\\n aktualni pracownicy, obecni pracownicy'''" (1, 4) (2, 45) + NEWLINE '\\n' (2, 45) (2, 46) + DEDENT '' (3, 0) (3, 0) """) def test_function(self): @@ -945,29 +1167,95 @@ async def bar(): pass DEDENT '' (7, 0) (7, 0) """) + def test_newline_after_parenthesized_block_with_comment(self): + self.check_tokenize('''\ +[ + # A comment here + 1 +] +''', """\ + OP '[' (1, 0) (1, 1) + NL '\\n' (1, 1) (1, 2) + COMMENT '# A comment here' (2, 4) (2, 20) + NL '\\n' (2, 20) (2, 21) + NUMBER '1' (3, 4) (3, 5) + NL '\\n' (3, 5) (3, 6) + OP ']' (4, 0) (4, 1) + NEWLINE '\\n' (4, 1) (4, 2) + """) + + def test_closing_parenthesis_from_different_line(self): + self.check_tokenize("); x", """\ + OP ')' (1, 0) (1, 1) + OP ';' (1, 1) (1, 2) + NAME 'x' (1, 3) (1, 4) + """) + + def test_multiline_non_ascii_fstring(self): + self.check_tokenize("""\ +a = f''' + Autorzy, którzy tą jednostkę mają wpisani jako AKTUALNA -- czyli'''""", """\ + NAME 'a' (1, 0) (1, 1) + OP '=' (1, 2) (1, 3) + 
FSTRING_START "f\'\'\'" (1, 4) (1, 8) + FSTRING_MIDDLE '\\n Autorzy, którzy tą jednostkę mają wpisani jako AKTUALNA -- czyli' (1, 8) (2, 68) + FSTRING_END "\'\'\'" (2, 68) (2, 71) + """) + + def test_multiline_non_ascii_fstring_with_expr(self): + self.check_tokenize("""\ +f''' + 🔗 This is a test {test_arg1}🔗 +🔗'''""", """\ + FSTRING_START "f\'\'\'" (1, 0) (1, 4) + FSTRING_MIDDLE '\\n 🔗 This is a test ' (1, 4) (2, 21) + OP '{' (2, 21) (2, 22) + NAME 'test_arg1' (2, 22) (2, 31) + OP '}' (2, 31) (2, 32) + FSTRING_MIDDLE '🔗\\n🔗' (2, 32) (3, 1) + FSTRING_END "\'\'\'" (3, 1) (3, 4) + """) + + # gh-139516, the '\n' is explicit to ensure no trailing whitespace which would invalidate the test + self.check_tokenize('''f"{f(a=lambda: 'à'\n)}"''', """\ + FSTRING_START \'f"\' (1, 0) (1, 2) + OP '{' (1, 2) (1, 3) + NAME 'f' (1, 3) (1, 4) + OP '(' (1, 4) (1, 5) + NAME 'a' (1, 5) (1, 6) + OP '=' (1, 6) (1, 7) + NAME 'lambda' (1, 7) (1, 13) + OP ':' (1, 13) (1, 14) + STRING "\'à\'" (1, 15) (1, 18) + NL '\\n' (1, 18) (1, 19) + OP ')' (2, 0) (2, 1) + OP '}' (2, 1) (2, 2) + FSTRING_END \'"\' (2, 2) (2, 3) + """) + class GenerateTokensTest(TokenizeTest): def check_tokenize(self, s, expected): # Format the tokens in s in a table format. # The ENDMARKER and final NEWLINE are omitted. f = StringIO(s) - result = stringify_tokens_from_source(generate_tokens(f.readline), s) + result = stringify_tokens_from_source(tokenize.generate_tokens(f.readline), s) self.assertEqual(result, expected.rstrip().splitlines()) def decistmt(s): result = [] - g = tokenize(BytesIO(s.encode('utf-8')).readline) # tokenize the string + g = tokenize.tokenize(BytesIO(s.encode('utf-8')).readline) # tokenize the string for toknum, tokval, _, _, _ in g: - if toknum == NUMBER and '.' in tokval: # replace NUMBER tokens + if toknum == tokenize.NUMBER and '.' 
in tokval: # replace NUMBER tokens result.extend([ - (NAME, 'Decimal'), - (OP, '('), - (STRING, repr(tokval)), - (OP, ')') + (tokenize.NAME, 'Decimal'), + (tokenize.OP, '('), + (tokenize.STRING, repr(tokval)), + (tokenize.OP, ')') ]) else: result.append((toknum, tokval)) - return untokenize(result).decode('utf-8') + return tokenize.untokenize(result).decode('utf-8').strip() class TestMisc(TestCase): @@ -991,6 +1279,13 @@ def test_decistmt(self): self.assertEqual(eval(decistmt(s)), Decimal('-3.217160342717258261933904529E-7')) + def test___all__(self): + expected = token.__all__ + [ + "TokenInfo", "TokenError", "generate_tokens", + "detect_encoding", "untokenize", "open", "tokenize", + ] + self.assertCountEqual(tokenize.__all__, expected) + class TestTokenizerAdheresToPep0263(TestCase): """ @@ -998,8 +1293,9 @@ class TestTokenizerAdheresToPep0263(TestCase): """ def _testFile(self, filename): - path = os.path.join(os.path.dirname(__file__), filename) - TestRoundtrip.check_roundtrip(self, open(path, 'rb')) + path = os.path.join(os.path.dirname(__file__), 'tokenizedata', filename) + with open(path, 'rb') as f: + TestRoundtrip.check_roundtrip(self, f) def test_utf8_coding_cookie_and_no_utf8_bom(self): f = 'tokenize_tests-utf8-coding-cookie-and-no-utf8-bom-sig.txt' @@ -1024,8 +1320,6 @@ def test_utf8_coding_cookie_and_utf8_bom(self): f = 'tokenize_tests-utf8-coding-cookie-and-utf8-bom-sig.txt' self._testFile(f) - # TODO: RUSTPYTHON - @unittest.expectedFailure # "bad_coding.py" and "bad_coding2.py" make the WASM CI fail def test_bad_coding_cookie(self): self.assertRaises(SyntaxError, self._testFile, 'bad_coding.py') self.assertRaises(SyntaxError, self._testFile, 'bad_coding2.py') @@ -1041,33 +1335,18 @@ def readline(): nonlocal first if not first: first = True - return line + yield line else: - return b'' + yield b'' # skip the initial encoding token and the end tokens - tokens = list(_tokenize(readline, encoding='utf-8'))[1:-2] - expected_tokens = [(3, '"ЉЊЈЁЂ"', (1, 0), 
(1, 7), '"ЉЊЈЁЂ"')] + tokens = list(tokenize._generate_tokens_from_c_tokenizer(readline().__next__, + encoding='utf-8', + extra_tokens=True))[:-2] + expected_tokens = [tokenize.TokenInfo(3, '"ЉЊЈЁЂ"', (1, 0), (1, 7), '"ЉЊЈЁЂ"')] self.assertEqual(tokens, expected_tokens, "bytes not decoded with encoding") - def test__tokenize_does_not_decode_with_encoding_none(self): - literal = '"ЉЊЈЁЂ"' - first = False - def readline(): - nonlocal first - if not first: - first = True - return literal - else: - return b'' - - # skip the end tokens - tokens = list(_tokenize(readline, encoding=None))[:-2] - expected_tokens = [(3, '"ЉЊЈЁЂ"', (1, 0), (1, 7), '"ЉЊЈЁЂ"')] - self.assertEqual(tokens, expected_tokens, - "string not tokenized when encoding is None") - class TestDetectEncoding(TestCase): @@ -1084,24 +1363,63 @@ def readline(): def test_no_bom_no_encoding_cookie(self): lines = ( - b'# something\n', + b'#!/home/\xc3\xa4/bin/python\n', + b'# something \xe2\x82\xac\n', b'print(something)\n', b'do_something(else)\n' ) - encoding, consumed_lines = detect_encoding(self.get_readline(lines)) + encoding, consumed_lines = tokenize.detect_encoding(self.get_readline(lines)) self.assertEqual(encoding, 'utf-8') self.assertEqual(consumed_lines, list(lines[:2])) + def test_no_bom_no_encoding_cookie_first_line_error(self): + lines = ( + b'#!/home/\xa4/bin/python\n\n', + b'print(something)\n', + b'do_something(else)\n' + ) + with self.assertRaises(SyntaxError): + tokenize.detect_encoding(self.get_readline(lines)) + + def test_no_bom_no_encoding_cookie_second_line_error(self): + lines = ( + b'#!/usr/bin/python\n', + b'# something \xe2\n', + b'print(something)\n', + b'do_something(else)\n' + ) + with self.assertRaises(SyntaxError): + tokenize.detect_encoding(self.get_readline(lines)) + def test_bom_no_cookie(self): lines = ( - b'\xef\xbb\xbf# something\n', + b'\xef\xbb\xbf#!/home/\xc3\xa4/bin/python\n', b'print(something)\n', b'do_something(else)\n' ) - encoding, consumed_lines = 
detect_encoding(self.get_readline(lines)) + encoding, consumed_lines = tokenize.detect_encoding(self.get_readline(lines)) self.assertEqual(encoding, 'utf-8-sig') self.assertEqual(consumed_lines, - [b'# something\n', b'print(something)\n']) + [b'#!/home/\xc3\xa4/bin/python\n', b'print(something)\n']) + + def test_bom_no_cookie_first_line_error(self): + lines = ( + b'\xef\xbb\xbf#!/home/\xa4/bin/python\n', + b'print(something)\n', + b'do_something(else)\n' + ) + with self.assertRaises(SyntaxError): + tokenize.detect_encoding(self.get_readline(lines)) + + def test_bom_no_cookie_second_line_error(self): + lines = ( + b'\xef\xbb\xbf#!/usr/bin/python\n', + b'# something \xe2\n', + b'print(something)\n', + b'do_something(else)\n' + ) + with self.assertRaises(SyntaxError): + tokenize.detect_encoding(self.get_readline(lines)) def test_cookie_first_line_no_bom(self): lines = ( @@ -1109,7 +1427,7 @@ def test_cookie_first_line_no_bom(self): b'print(something)\n', b'do_something(else)\n' ) - encoding, consumed_lines = detect_encoding(self.get_readline(lines)) + encoding, consumed_lines = tokenize.detect_encoding(self.get_readline(lines)) self.assertEqual(encoding, 'iso-8859-1') self.assertEqual(consumed_lines, [b'# -*- coding: latin-1 -*-\n']) @@ -1119,7 +1437,7 @@ def test_matched_bom_and_cookie_first_line(self): b'print(something)\n', b'do_something(else)\n' ) - encoding, consumed_lines = detect_encoding(self.get_readline(lines)) + encoding, consumed_lines = tokenize.detect_encoding(self.get_readline(lines)) self.assertEqual(encoding, 'utf-8-sig') self.assertEqual(consumed_lines, [b'# coding=utf-8\n']) @@ -1130,7 +1448,7 @@ def test_mismatched_bom_and_cookie_first_line_raises_syntaxerror(self): b'do_something(else)\n' ) readline = self.get_readline(lines) - self.assertRaises(SyntaxError, detect_encoding, readline) + self.assertRaises(SyntaxError, tokenize.detect_encoding, readline) def test_cookie_second_line_no_bom(self): lines = ( @@ -1139,7 +1457,7 @@ def 
test_cookie_second_line_no_bom(self): b'print(something)\n', b'do_something(else)\n' ) - encoding, consumed_lines = detect_encoding(self.get_readline(lines)) + encoding, consumed_lines = tokenize.detect_encoding(self.get_readline(lines)) self.assertEqual(encoding, 'ascii') expected = [b'#! something\n', b'# vim: set fileencoding=ascii :\n'] self.assertEqual(consumed_lines, expected) @@ -1151,7 +1469,7 @@ def test_matched_bom_and_cookie_second_line(self): b'print(something)\n', b'do_something(else)\n' ) - encoding, consumed_lines = detect_encoding(self.get_readline(lines)) + encoding, consumed_lines = tokenize.detect_encoding(self.get_readline(lines)) self.assertEqual(encoding, 'utf-8-sig') self.assertEqual(consumed_lines, [b'#! something\n', b'f# coding=utf-8\n']) @@ -1164,7 +1482,7 @@ def test_mismatched_bom_and_cookie_second_line_raises_syntaxerror(self): b'do_something(else)\n' ) readline = self.get_readline(lines) - self.assertRaises(SyntaxError, detect_encoding, readline) + self.assertRaises(SyntaxError, tokenize.detect_encoding, readline) def test_cookie_second_line_noncommented_first_line(self): lines = ( @@ -1172,21 +1490,65 @@ def test_cookie_second_line_noncommented_first_line(self): b'# vim: set fileencoding=iso8859-15 :\n', b"print('\xe2\x82\xac')\n" ) - encoding, consumed_lines = detect_encoding(self.get_readline(lines)) + encoding, consumed_lines = tokenize.detect_encoding(self.get_readline(lines)) self.assertEqual(encoding, 'utf-8') expected = [b"print('\xc2\xa3')\n"] self.assertEqual(consumed_lines, expected) - def test_cookie_second_line_commented_first_line(self): + def test_first_non_utf8_coding_line(self): lines = ( - b"#print('\xc2\xa3')\n", - b'# vim: set fileencoding=iso8859-15 :\n', - b"print('\xe2\x82\xac')\n" + b'#coding:iso-8859-15 \xa4\n', + b'print(something)\n' ) - encoding, consumed_lines = detect_encoding(self.get_readline(lines)) - self.assertEqual(encoding, 'iso8859-15') - expected = [b"#print('\xc2\xa3')\n", b'# vim: set 
fileencoding=iso8859-15 :\n'] - self.assertEqual(consumed_lines, expected) + encoding, consumed_lines = tokenize.detect_encoding(self.get_readline(lines)) + self.assertEqual(encoding, 'iso-8859-15') + self.assertEqual(consumed_lines, list(lines[:1])) + + def test_first_utf8_coding_line_error(self): + lines = ( + b'#coding:ascii \xc3\xa4\n', + b'print(something)\n' + ) + with self.assertRaises(SyntaxError): + tokenize.detect_encoding(self.get_readline(lines)) + + def test_second_non_utf8_coding_line(self): + lines = ( + b'#!/usr/bin/python\n', + b'#coding:iso-8859-15 \xa4\n', + b'print(something)\n' + ) + encoding, consumed_lines = tokenize.detect_encoding(self.get_readline(lines)) + self.assertEqual(encoding, 'iso-8859-15') + self.assertEqual(consumed_lines, list(lines[:2])) + + def test_second_utf8_coding_line_error(self): + lines = ( + b'#!/usr/bin/python\n', + b'#coding:ascii \xc3\xa4\n', + b'print(something)\n' + ) + with self.assertRaises(SyntaxError): + tokenize.detect_encoding(self.get_readline(lines)) + + def test_non_utf8_shebang(self): + lines = ( + b'#!/home/\xa4/bin/python\n', + b'#coding:iso-8859-15\n', + b'print(something)\n' + ) + encoding, consumed_lines = tokenize.detect_encoding(self.get_readline(lines)) + self.assertEqual(encoding, 'iso-8859-15') + self.assertEqual(consumed_lines, list(lines[:2])) + + def test_utf8_shebang_error(self): + lines = ( + b'#!/home/\xc3\xa4/bin/python\n', + b'#coding:ascii\n', + b'print(something)\n' + ) + with self.assertRaises(SyntaxError): + tokenize.detect_encoding(self.get_readline(lines)) def test_cookie_second_line_empty_first_line(self): lines = ( @@ -1194,13 +1556,77 @@ def test_cookie_second_line_empty_first_line(self): b'# vim: set fileencoding=iso8859-15 :\n', b"print('\xe2\x82\xac')\n" ) - encoding, consumed_lines = detect_encoding(self.get_readline(lines)) + encoding, consumed_lines = tokenize.detect_encoding(self.get_readline(lines)) self.assertEqual(encoding, 'iso8859-15') expected = [b'\n', b'# vim: 
set fileencoding=iso8859-15 :\n'] self.assertEqual(consumed_lines, expected) + def test_cookie_third_line(self): + lines = ( + b'#!/home/\xc3\xa4/bin/python\n', + b'# something\n', + b'# vim: set fileencoding=ascii :\n', + b'print(something)\n', + b'do_something(else)\n' + ) + encoding, consumed_lines = tokenize.detect_encoding(self.get_readline(lines)) + self.assertEqual(encoding, 'utf-8') + self.assertEqual(consumed_lines, list(lines[:2])) + + def test_double_coding_line(self): + # If the first line matches the second line is ignored. + lines = ( + b'#coding:iso8859-15\n', + b'#coding:latin1\n', + b'print(something)\n' + ) + encoding, consumed_lines = tokenize.detect_encoding(self.get_readline(lines)) + self.assertEqual(encoding, 'iso8859-15') + self.assertEqual(consumed_lines, list(lines[:1])) + + def test_double_coding_same_line(self): + lines = ( + b'#coding:iso8859-15 coding:latin1\n', + b'print(something)\n' + ) + encoding, consumed_lines = tokenize.detect_encoding(self.get_readline(lines)) + self.assertEqual(encoding, 'iso8859-15') + self.assertEqual(consumed_lines, list(lines[:1])) + + def test_double_coding_utf8(self): + lines = ( + b'#coding:utf-8\n', + b'#coding:latin1\n', + b'print(something)\n' + ) + encoding, consumed_lines = tokenize.detect_encoding(self.get_readline(lines)) + self.assertEqual(encoding, 'utf-8') + self.assertEqual(consumed_lines, list(lines[:1])) + + def test_nul_in_first_coding_line(self): + lines = ( + b'#coding:iso8859-15\x00\n', + b'\n', + b'\n', + b'print(something)\n' + ) + with self.assertRaisesRegex(SyntaxError, + "source code cannot contain null bytes"): + tokenize.detect_encoding(self.get_readline(lines)) + + def test_nul_in_second_coding_line(self): + lines = ( + b'#!/usr/bin/python\n', + b'#coding:iso8859-15\x00\n', + b'\n', + b'print(something)\n' + ) + with self.assertRaisesRegex(SyntaxError, + "source code cannot contain null bytes"): + tokenize.detect_encoding(self.get_readline(lines)) + def 
test_latin1_normalization(self): - # See get_normal_name() in tokenizer.c. + # See get_normal_name() in Parser/tokenizer/helpers.c. encodings = ("latin-1", "iso-8859-1", "iso-latin-1", "latin-1-unix", "iso-8859-1-unix", "iso-latin-1-mac") for encoding in encodings: @@ -1211,21 +1637,20 @@ def test_latin1_normalization(self): b"print(things)\n", b"do_something += 4\n") rl = self.get_readline(lines) - found, consumed_lines = detect_encoding(rl) + found, consumed_lines = tokenize.detect_encoding(rl) self.assertEqual(found, "iso-8859-1") def test_syntaxerror_latin1(self): - # Issue 14629: need to raise SyntaxError if the first + # Issue 14629: need to raise TokenError if the first # line(s) have non-UTF-8 characters lines = ( b'print("\xdf")', # Latin-1: LATIN SMALL LETTER SHARP S ) readline = self.get_readline(lines) - self.assertRaises(SyntaxError, detect_encoding, readline) - + self.assertRaises(SyntaxError, tokenize.detect_encoding, readline) def test_utf8_normalization(self): - # See get_normal_name() in tokenizer.c. + # See get_normal_name() in Parser/tokenizer/helpers.c. 
encodings = ("utf-8", "utf-8-mac", "utf-8-unix") for encoding in encodings: for rep in ("-", "_"): @@ -1234,39 +1659,40 @@ def test_utf8_normalization(self): b"# coding: " + enc.encode("ascii") + b"\n", b"1 + 3\n") rl = self.get_readline(lines) - found, consumed_lines = detect_encoding(rl) + found, consumed_lines = tokenize.detect_encoding(rl) self.assertEqual(found, "utf-8") def test_short_files(self): readline = self.get_readline((b'print(something)\n',)) - encoding, consumed_lines = detect_encoding(readline) + encoding, consumed_lines = tokenize.detect_encoding(readline) self.assertEqual(encoding, 'utf-8') self.assertEqual(consumed_lines, [b'print(something)\n']) - encoding, consumed_lines = detect_encoding(self.get_readline(())) + encoding, consumed_lines = tokenize.detect_encoding(self.get_readline(())) self.assertEqual(encoding, 'utf-8') self.assertEqual(consumed_lines, []) readline = self.get_readline((b'\xef\xbb\xbfprint(something)\n',)) - encoding, consumed_lines = detect_encoding(readline) + encoding, consumed_lines = tokenize.detect_encoding(readline) self.assertEqual(encoding, 'utf-8-sig') self.assertEqual(consumed_lines, [b'print(something)\n']) readline = self.get_readline((b'\xef\xbb\xbf',)) - encoding, consumed_lines = detect_encoding(readline) + encoding, consumed_lines = tokenize.detect_encoding(readline) self.assertEqual(encoding, 'utf-8-sig') self.assertEqual(consumed_lines, []) readline = self.get_readline((b'# coding: bad\n',)) - self.assertRaises(SyntaxError, detect_encoding, readline) + self.assertRaises(SyntaxError, tokenize.detect_encoding, readline) def test_false_encoding(self): # Issue 18873: "Encoding" detected in non-comment lines readline = self.get_readline((b'print("#coding=fake")',)) - encoding, consumed_lines = detect_encoding(readline) + encoding, consumed_lines = tokenize.detect_encoding(readline) self.assertEqual(encoding, 'utf-8') self.assertEqual(consumed_lines, [b'print("#coding=fake")']) + @support.thread_unsafe def 
test_open(self): filename = os_helper.TESTFN + '.py' self.addCleanup(os_helper.unlink, filename) @@ -1276,14 +1702,14 @@ def test_open(self): with open(filename, 'w', encoding=encoding) as fp: print("# coding: %s" % encoding, file=fp) print("print('euro:\u20ac')", file=fp) - with tokenize_open(filename) as fp: + with tokenize.open(filename) as fp: self.assertEqual(fp.encoding, encoding) self.assertEqual(fp.mode, 'r') # test BOM (no coding cookie) with open(filename, 'w', encoding='utf-8-sig') as fp: print("print('euro:\u20ac')", file=fp) - with tokenize_open(filename) as fp: + with tokenize.open(filename) as fp: self.assertEqual(fp.encoding, 'utf-8-sig') self.assertEqual(fp.mode, 'r') @@ -1310,16 +1736,16 @@ def readline(self): ins = Bunk(lines, path) # Make sure lacking a name isn't an issue. del ins.name - detect_encoding(ins.readline) + tokenize.detect_encoding(ins.readline) with self.assertRaisesRegex(SyntaxError, '.*{}'.format(path)): ins = Bunk(lines, path) - detect_encoding(ins.readline) + tokenize.detect_encoding(ins.readline) def test_open_error(self): # Issue #23840: open() must close the binary file on error m = BytesIO(b'#coding:xxx') with mock.patch('tokenize._builtin_open', return_value=m): - self.assertRaises(SyntaxError, tokenize_open, 'foobar') + self.assertRaises(SyntaxError, tokenize.open, 'foobar') self.assertTrue(m.closed) @@ -1327,17 +1753,20 @@ class TestTokenize(TestCase): def test_tokenize(self): import tokenize as tokenize_module - encoding = object() + encoding = "utf-8" encoding_used = None def mock_detect_encoding(readline): return encoding, [b'first', b'second'] - def mock__tokenize(readline, encoding): + def mock__tokenize(readline, encoding, **kwargs): nonlocal encoding_used encoding_used = encoding out = [] while True: - next_line = readline() + try: + next_line = readline() + except StopIteration: + return out if next_line: out.append(next_line) continue @@ -1352,16 +1781,16 @@ def mock_readline(): return str(counter).encode() 
orig_detect_encoding = tokenize_module.detect_encoding - orig__tokenize = tokenize_module._tokenize + orig_c_token = tokenize_module._generate_tokens_from_c_tokenizer tokenize_module.detect_encoding = mock_detect_encoding - tokenize_module._tokenize = mock__tokenize + tokenize_module._generate_tokens_from_c_tokenizer = mock__tokenize try: - results = tokenize(mock_readline) - self.assertEqual(list(results), + results = tokenize.tokenize(mock_readline) + self.assertEqual(list(results)[1:], [b'first', b'second', b'1', b'2', b'3', b'4']) finally: tokenize_module.detect_encoding = orig_detect_encoding - tokenize_module._tokenize = orig__tokenize + tokenize_module._generate_tokens_from_c_tokenizer = orig_c_token self.assertEqual(encoding_used, encoding) @@ -1373,23 +1802,23 @@ def test_oneline_defs(self): buf = '\n'.join(buf) # Test that 500 consequent, one-line defs is OK - toks = list(tokenize(BytesIO(buf.encode('utf-8')).readline)) + toks = list(tokenize.tokenize(BytesIO(buf.encode('utf-8')).readline)) self.assertEqual(toks[-3].string, 'OK') # [-1] is always ENDMARKER # [-2] is always NEWLINE def assertExactTypeEqual(self, opstr, *optypes): - tokens = list(tokenize(BytesIO(opstr.encode('utf-8')).readline)) + tokens = list(tokenize.tokenize(BytesIO(opstr.encode('utf-8')).readline)) num_optypes = len(optypes) self.assertEqual(len(tokens), 3 + num_optypes) - self.assertEqual(tok_name[tokens[0].exact_type], - tok_name[ENCODING]) + self.assertEqual(tokenize.tok_name[tokens[0].exact_type], + tokenize.tok_name[tokenize.ENCODING]) for i in range(num_optypes): - self.assertEqual(tok_name[tokens[i + 1].exact_type], - tok_name[optypes[i]]) - self.assertEqual(tok_name[tokens[1 + num_optypes].exact_type], - tok_name[token.NEWLINE]) - self.assertEqual(tok_name[tokens[2 + num_optypes].exact_type], - tok_name[token.ENDMARKER]) + self.assertEqual(tokenize.tok_name[tokens[i + 1].exact_type], + tokenize.tok_name[optypes[i]]) + self.assertEqual(tokenize.tok_name[tokens[1 + 
num_optypes].exact_type], + tokenize.tok_name[token.NEWLINE]) + self.assertEqual(tokenize.tok_name[tokens[2 + num_optypes].exact_type], + tokenize.tok_name[token.ENDMARKER]) def test_exact_type(self): self.assertExactTypeEqual('()', token.LPAR, token.RPAR) @@ -1439,11 +1868,11 @@ def test_exact_type(self): self.assertExactTypeEqual('@=', token.ATEQUAL) self.assertExactTypeEqual('a**2+b**2==c**2', - NAME, token.DOUBLESTAR, NUMBER, + tokenize.NAME, token.DOUBLESTAR, tokenize.NUMBER, token.PLUS, - NAME, token.DOUBLESTAR, NUMBER, + tokenize.NAME, token.DOUBLESTAR, tokenize.NUMBER, token.EQEQUAL, - NAME, token.DOUBLESTAR, NUMBER) + tokenize.NAME, token.DOUBLESTAR, tokenize.NUMBER) self.assertExactTypeEqual('{1, 2, 3}', token.LBRACE, token.NUMBER, token.COMMA, @@ -1463,19 +1892,55 @@ def test_pathological_trailing_whitespace(self): def test_comment_at_the_end_of_the_source_without_newline(self): # See http://bugs.python.org/issue44667 source = 'b = 1\n\n#test' - expected_tokens = [token.NAME, token.EQUAL, token.NUMBER, token.NEWLINE, token.NL, token.COMMENT] + expected_tokens = [ + tokenize.TokenInfo(type=token.ENCODING, string='utf-8', start=(0, 0), end=(0, 0), line=''), + tokenize.TokenInfo(type=token.NAME, string='b', start=(1, 0), end=(1, 1), line='b = 1\n'), + tokenize.TokenInfo(type=token.OP, string='=', start=(1, 2), end=(1, 3), line='b = 1\n'), + tokenize.TokenInfo(type=token.NUMBER, string='1', start=(1, 4), end=(1, 5), line='b = 1\n'), + tokenize.TokenInfo(type=token.NEWLINE, string='\n', start=(1, 5), end=(1, 6), line='b = 1\n'), + tokenize.TokenInfo(type=token.NL, string='\n', start=(2, 0), end=(2, 1), line='\n'), + tokenize.TokenInfo(type=token.COMMENT, string='#test', start=(3, 0), end=(3, 5), line='#test'), + tokenize.TokenInfo(type=token.NL, string='', start=(3, 5), end=(3, 6), line='#test'), + tokenize.TokenInfo(type=token.ENDMARKER, string='', start=(4, 0), end=(4, 0), line='') + ] + + tokens = 
list(tokenize.tokenize(BytesIO(source.encode('utf-8')).readline)) + self.assertEqual(tokens, expected_tokens) + + @unittest.expectedFailure # TODO: RUSTPYTHON; Diff is 869 characters long. Set self.maxDiff to None to see it. + def test_newline_and_space_at_the_end_of_the_source_without_newline(self): + # See https://github.com/python/cpython/issues/105435 + source = 'a\n ' + expected_tokens = [ + tokenize.TokenInfo(token.ENCODING, string='utf-8', start=(0, 0), end=(0, 0), line=''), + tokenize.TokenInfo(token.NAME, string='a', start=(1, 0), end=(1, 1), line='a\n'), + tokenize.TokenInfo(token.NEWLINE, string='\n', start=(1, 1), end=(1, 2), line='a\n'), + tokenize.TokenInfo(token.NL, string='', start=(2, 1), end=(2, 2), line=' '), + tokenize.TokenInfo(token.ENDMARKER, string='', start=(3, 0), end=(3, 0), line='') + ] + + tokens = list(tokenize.tokenize(BytesIO(source.encode('utf-8')).readline)) + self.assertEqual(tokens, expected_tokens) + + @unittest.expectedFailure # TODO: RUSTPYTHON; AssertionError: b'SyntaxError' not found in b'OSError: stream did not contain valid UTF-8\n' + def test_invalid_character_in_fstring_middle(self): + # See gh-103824 + script = b'''F""" + \xe5"""''' + + with os_helper.temp_dir() as temp_dir: + filename = os.path.join(temp_dir, "script.py") + with open(filename, 'wb') as file: + file.write(script) + rs, _ = run_python_until_end(filename) + self.assertIn(b"SyntaxError", rs.err) - tokens = list(tokenize(BytesIO(source.encode('utf-8')).readline)) - self.assertEqual(tok_name[tokens[0].exact_type], tok_name[ENCODING]) - for i in range(6): - self.assertEqual(tok_name[tokens[i + 1].exact_type], tok_name[expected_tokens[i]]) - self.assertEqual(tok_name[tokens[-1].exact_type], tok_name[token.ENDMARKER]) class UntokenizeTest(TestCase): def test_bad_input_order(self): # raise if previous row - u = Untokenizer() + u = tokenize.Untokenizer() u.prev_row = 2 u.prev_col = 2 with self.assertRaises(ValueError) as cm: @@ -1487,7 +1952,7 @@ def 
test_bad_input_order(self): def test_backslash_continuation(self): # The problem is that \ leaves no token - u = Untokenizer() + u = tokenize.Untokenizer() u.prev_row = 1 u.prev_col = 1 u.tokens = [] @@ -1499,17 +1964,33 @@ def test_backslash_continuation(self): TestRoundtrip.check_roundtrip(self, 'a\n b\n c\n \\\n c\n') def test_iter_compat(self): - u = Untokenizer() - token = (NAME, 'Hello') - tokens = [(ENCODING, 'utf-8'), token] + u = tokenize.Untokenizer() + token = (tokenize.NAME, 'Hello') + tokens = [(tokenize.ENCODING, 'utf-8'), token] u.compat(token, iter([])) self.assertEqual(u.tokens, ["Hello "]) - u = Untokenizer() + u = tokenize.Untokenizer() self.assertEqual(u.untokenize(iter([token])), 'Hello ') - u = Untokenizer() + u = tokenize.Untokenizer() self.assertEqual(u.untokenize(iter(tokens)), 'Hello ') self.assertEqual(u.encoding, 'utf-8') - self.assertEqual(untokenize(iter(tokens)), b'Hello ') + self.assertEqual(tokenize.untokenize(iter(tokens)), b'Hello ') + + +def contains_ambiguous_backslash(source): + """Return `True` if the source contains a backslash on a + line by itself. For example: + + a = (1 + \\ + ) + + Code like this cannot be untokenized exactly. This is because + the tokenizer does not produce any tokens for the line containing + the backslash and so there is no way to know its indent. + """ + pattern = re.compile(br'\n\s*\\\r?\n') + return pattern.search(source) is not None class TestRoundtrip(TestCase): @@ -1522,6 +2003,9 @@ def check_roundtrip(self, f): tokenize.untokenize(), and the latter tokenized again to 2-tuples. The test fails if the 3 pair tokenizations do not match. + If the source code can be untokenized unambiguously, the + untokenized code must match the original code exactly. + When untokenize bugs are fixed, untokenize with 5-tuples should reproduce code that does not contain a backslash continuation following spaces. A proper test should test this. 
@@ -1531,21 +2015,38 @@ def check_roundtrip(self, f): code = f.encode('utf-8') else: code = f.read() - f.close() readline = iter(code.splitlines(keepends=True)).__next__ - tokens5 = list(tokenize(readline)) + tokens5 = list(tokenize.tokenize(readline)) tokens2 = [tok[:2] for tok in tokens5] # Reproduce tokens2 from pairs - bytes_from2 = untokenize(tokens2) + bytes_from2 = tokenize.untokenize(tokens2) readline2 = iter(bytes_from2.splitlines(keepends=True)).__next__ - tokens2_from2 = [tok[:2] for tok in tokenize(readline2)] + tokens2_from2 = [tok[:2] for tok in tokenize.tokenize(readline2)] self.assertEqual(tokens2_from2, tokens2) # Reproduce tokens2 from 5-tuples - bytes_from5 = untokenize(tokens5) + bytes_from5 = tokenize.untokenize(tokens5) readline5 = iter(bytes_from5.splitlines(keepends=True)).__next__ - tokens2_from5 = [tok[:2] for tok in tokenize(readline5)] + tokens2_from5 = [tok[:2] for tok in tokenize.tokenize(readline5)] self.assertEqual(tokens2_from5, tokens2) + if not contains_ambiguous_backslash(code): + # The BOM does not produce a token so there is no way to preserve it. + code_without_bom = code.removeprefix(b'\xef\xbb\xbf') + readline = iter(code_without_bom.splitlines(keepends=True)).__next__ + untokenized_code = tokenize.untokenize(tokenize.tokenize(readline)) + self.assertEqual(code_without_bom, untokenized_code) + + def check_line_extraction(self, f): + if isinstance(f, str): + code = f.encode('utf-8') + else: + code = f.read() + readline = iter(code.splitlines(keepends=True)).__next__ + for tok in tokenize.tokenize(readline): + if tok.type in {tokenize.ENCODING, tokenize.ENDMARKER}: + continue + self.assertEqual(tok.string, tok.line[tok.start[1]: tok.end[1]]) + def test_roundtrip(self): # There are some standard formatting practices that are easy to get right. 
@@ -1561,7 +2062,7 @@ def test_roundtrip(self): self.check_roundtrip("if x == 1 : \n" " print(x)\n") - fn = support.findfile("tokenize_tests.txt") + fn = support.findfile("tokenize_tests.txt", subdir="tokenizedata") with open(fn, 'rb') as f: self.check_roundtrip(f) self.check_roundtrip("if x == 1:\n" @@ -1585,6 +2086,67 @@ def test_roundtrip(self): " print('Can not import' # comment2\n)" "else: print('Loaded')\n") + self.check_roundtrip("f'\\N{EXCLAMATION MARK}'") + self.check_roundtrip(r"f'\\N{SNAKE}'") + self.check_roundtrip(r"f'\\N{{SNAKE}}'") + self.check_roundtrip(r"f'\N{SNAKE}'") + self.check_roundtrip(r"f'\\\N{SNAKE}'") + self.check_roundtrip(r"f'\\\\\N{SNAKE}'") + self.check_roundtrip(r"f'\\\\\\\N{SNAKE}'") + + self.check_roundtrip(r"f'\\N{1}'") + self.check_roundtrip(r"f'\\\\N{2}'") + self.check_roundtrip(r"f'\\\\\\N{3}'") + self.check_roundtrip(r"f'\\\\\\\\N{4}'") + + self.check_roundtrip(r"f'\\N{{'") + self.check_roundtrip(r"f'\\\\N{{'") + self.check_roundtrip(r"f'\\\\\\N{{'") + self.check_roundtrip(r"f'\\\\\\\\N{{'") + + self.check_roundtrip(r"f'\n{{foo}}'") + self.check_roundtrip(r"f'\\n{{foo}}'") + self.check_roundtrip(r"f'\\\n{{foo}}'") + self.check_roundtrip(r"f'\\\\n{{foo}}'") + + self.check_roundtrip(r"f'\t{{foo}}'") + self.check_roundtrip(r"f'\\t{{foo}}'") + self.check_roundtrip(r"f'\\\t{{foo}}'") + self.check_roundtrip(r"f'\\\\t{{foo}}'") + + self.check_roundtrip(r"rf'\t{{foo}}'") + self.check_roundtrip(r"rf'\\t{{foo}}'") + self.check_roundtrip(r"rf'\\\t{{foo}}'") + self.check_roundtrip(r"rf'\\\\t{{foo}}'") + + self.check_roundtrip(r"rf'\{{foo}}'") + self.check_roundtrip(r"f'\\{{foo}}'") + self.check_roundtrip(r"rf'\\\{{foo}}'") + self.check_roundtrip(r"f'\\\\{{foo}}'") + cases = [ + """ +if 1: + "foo" +"bar" +""", + """ +if 1: + ("foo" + "bar") +""", + """ +if 1: + "foo" + "bar" +""" ] + for case in cases: + self.check_roundtrip(case) + + self.check_roundtrip(r"t'{ {}}'") + self.check_roundtrip(r"t'{f'{ {}}'}{ {}}'") + 
self.check_roundtrip(r"f'{t'{ {}}'}{ {}}'") + + def test_continuation(self): # Balancing continuation self.check_roundtrip("a = (3,4, \n" @@ -1611,26 +2173,14 @@ def test_string_concatenation(self): # Two string literals on the same line self.check_roundtrip("'' ''") - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_random_files(self): # Test roundtrip on random python modules. # pass the '-ucpu' option to process the full directory. import glob, random - fn = support.findfile("tokenize_tests.txt") - tempdir = os.path.dirname(fn) or os.curdir + tempdir = os.path.dirname(__file__) or os.curdir testfiles = glob.glob(os.path.join(glob.escape(tempdir), "test*.py")) - # Tokenize is broken on test_pep3131.py because regular expressions are - # broken on the obscure unicode identifiers in it. *sigh* - # With roundtrip extended to test the 5-tuple mode of untokenize, - # 7 more testfiles fail. Remove them also until the failure is diagnosed. - - testfiles.remove(os.path.join(tempdir, "test_unicode_identifiers.py")) - for f in ('buffer', 'builtin', 'fileio', 'inspect', 'os', 'platform', 'sys'): - testfiles.remove(os.path.join(tempdir, "test_%s.py") % f) - if not support.is_resource_enabled("cpu"): testfiles = random.sample(testfiles, 10) @@ -1640,12 +2190,13 @@ def test_random_files(self): with open(testfile, 'rb') as f: with self.subTest(file=testfile): self.check_roundtrip(f) + self.check_line_extraction(f) def roundtrip(self, code): if isinstance(code, str): code = code.encode('utf-8') - return untokenize(tokenize(BytesIO(code).readline)).decode('utf-8') + return tokenize.untokenize(tokenize.tokenize(BytesIO(code).readline)).decode('utf-8') def test_indentation_semantics_retained(self): """ @@ -1658,5 +2209,1279 @@ def test_indentation_semantics_retained(self): self.check_roundtrip(code) +class InvalidPythonTests(TestCase): + def test_number_followed_by_name(self): + # See issue #gh-105549 + source = "2sin(x)" + expected_tokens = [ + 
tokenize.TokenInfo(type=token.NUMBER, string='2', start=(1, 0), end=(1, 1), line='2sin(x)'), + tokenize.TokenInfo(type=token.NAME, string='sin', start=(1, 1), end=(1, 4), line='2sin(x)'), + tokenize.TokenInfo(type=token.OP, string='(', start=(1, 4), end=(1, 5), line='2sin(x)'), + tokenize.TokenInfo(type=token.NAME, string='x', start=(1, 5), end=(1, 6), line='2sin(x)'), + tokenize.TokenInfo(type=token.OP, string=')', start=(1, 6), end=(1, 7), line='2sin(x)'), + tokenize.TokenInfo(type=token.NEWLINE, string='', start=(1, 7), end=(1, 8), line='2sin(x)'), + tokenize.TokenInfo(type=token.ENDMARKER, string='', start=(2, 0), end=(2, 0), line='') + ] + + tokens = list(tokenize.generate_tokens(StringIO(source).readline)) + self.assertEqual(tokens, expected_tokens) + + @unittest.expectedFailure # TODO: RUSTPYTHON; Diff is 855 characters long. Set self.maxDiff to None to see it. + def test_number_starting_with_zero(self): + source = "01234" + expected_tokens = [ + tokenize.TokenInfo(type=token.NUMBER, string='01234', start=(1, 0), end=(1, 5), line='01234'), + tokenize.TokenInfo(type=token.NEWLINE, string='', start=(1, 5), end=(1, 6), line='01234'), + tokenize.TokenInfo(type=token.ENDMARKER, string='', start=(2, 0), end=(2, 0), line='') + ] + + tokens = list(tokenize.generate_tokens(StringIO(source).readline)) + self.assertEqual(tokens, expected_tokens) + +class CTokenizeTest(TestCase): + def check_tokenize(self, s, expected): + # Format the tokens in s in a table format. + # The ENDMARKER and final NEWLINE are omitted. 
+ f = StringIO(s) + with self.subTest(source=s): + result = stringify_tokens_from_source( + tokenize._generate_tokens_from_c_tokenizer(f.readline), s + ) + self.assertEqual(result, expected.rstrip().splitlines()) + + def test_encoding(self): + def readline(encoding): + yield "1+1".encode(encoding) + + expected = [ + tokenize.TokenInfo(type=tokenize.NUMBER, string='1', start=(1, 0), end=(1, 1), line='1+1'), + tokenize.TokenInfo(type=tokenize.OP, string='+', start=(1, 1), end=(1, 2), line='1+1'), + tokenize.TokenInfo(type=tokenize.NUMBER, string='1', start=(1, 2), end=(1, 3), line='1+1'), + tokenize.TokenInfo(type=tokenize.NEWLINE, string='', start=(1, 3), end=(1, 4), line='1+1'), + tokenize.TokenInfo(type=tokenize.ENDMARKER, string='', start=(2, 0), end=(2, 0), line='') + ] + for encoding in ["utf-8", "latin-1", "utf-16"]: + with self.subTest(encoding=encoding): + tokens = list(tokenize._generate_tokens_from_c_tokenizer( + readline(encoding).__next__, + extra_tokens=True, + encoding=encoding, + )) + self.assertEqual(tokens, expected) + + def test_int(self): + + self.check_tokenize('0xff <= 255', """\ + NUMBER '0xff' (1, 0) (1, 4) + LESSEQUAL '<=' (1, 5) (1, 7) + NUMBER '255' (1, 8) (1, 11) + """) + + self.check_tokenize('0b10 <= 255', """\ + NUMBER '0b10' (1, 0) (1, 4) + LESSEQUAL '<=' (1, 5) (1, 7) + NUMBER '255' (1, 8) (1, 11) + """) + + self.check_tokenize('0o123 <= 0O123', """\ + NUMBER '0o123' (1, 0) (1, 5) + LESSEQUAL '<=' (1, 6) (1, 8) + NUMBER '0O123' (1, 9) (1, 14) + """) + + self.check_tokenize('1234567 > ~0x15', """\ + NUMBER '1234567' (1, 0) (1, 7) + GREATER '>' (1, 8) (1, 9) + TILDE '~' (1, 10) (1, 11) + NUMBER '0x15' (1, 11) (1, 15) + """) + + self.check_tokenize('2134568 != 1231515', """\ + NUMBER '2134568' (1, 0) (1, 7) + NOTEQUAL '!=' (1, 8) (1, 10) + NUMBER '1231515' (1, 11) (1, 18) + """) + + self.check_tokenize('(-124561-1) & 200000000', """\ + LPAR '(' (1, 0) (1, 1) + MINUS '-' (1, 1) (1, 2) + NUMBER '124561' (1, 2) (1, 8) + MINUS '-' (1, 8) (1, 
9) + NUMBER '1' (1, 9) (1, 10) + RPAR ')' (1, 10) (1, 11) + AMPER '&' (1, 12) (1, 13) + NUMBER '200000000' (1, 14) (1, 23) + """) + + self.check_tokenize('0xdeadbeef != -1', """\ + NUMBER '0xdeadbeef' (1, 0) (1, 10) + NOTEQUAL '!=' (1, 11) (1, 13) + MINUS '-' (1, 14) (1, 15) + NUMBER '1' (1, 15) (1, 16) + """) + + self.check_tokenize('0xdeadc0de & 12345', """\ + NUMBER '0xdeadc0de' (1, 0) (1, 10) + AMPER '&' (1, 11) (1, 12) + NUMBER '12345' (1, 13) (1, 18) + """) + + self.check_tokenize('0xFF & 0x15 | 1234', """\ + NUMBER '0xFF' (1, 0) (1, 4) + AMPER '&' (1, 5) (1, 6) + NUMBER '0x15' (1, 7) (1, 11) + VBAR '|' (1, 12) (1, 13) + NUMBER '1234' (1, 14) (1, 18) + """) + + def test_float(self): + + self.check_tokenize('x = 3.14159', """\ + NAME 'x' (1, 0) (1, 1) + EQUAL '=' (1, 2) (1, 3) + NUMBER '3.14159' (1, 4) (1, 11) + """) + + self.check_tokenize('x = 314159.', """\ + NAME 'x' (1, 0) (1, 1) + EQUAL '=' (1, 2) (1, 3) + NUMBER '314159.' (1, 4) (1, 11) + """) + + self.check_tokenize('x = .314159', """\ + NAME 'x' (1, 0) (1, 1) + EQUAL '=' (1, 2) (1, 3) + NUMBER '.314159' (1, 4) (1, 11) + """) + + self.check_tokenize('x = 3e14159', """\ + NAME 'x' (1, 0) (1, 1) + EQUAL '=' (1, 2) (1, 3) + NUMBER '3e14159' (1, 4) (1, 11) + """) + + self.check_tokenize('x = 3E123', """\ + NAME 'x' (1, 0) (1, 1) + EQUAL '=' (1, 2) (1, 3) + NUMBER '3E123' (1, 4) (1, 9) + """) + + self.check_tokenize('x+y = 3e-1230', """\ + NAME 'x' (1, 0) (1, 1) + PLUS '+' (1, 1) (1, 2) + NAME 'y' (1, 2) (1, 3) + EQUAL '=' (1, 4) (1, 5) + NUMBER '3e-1230' (1, 6) (1, 13) + """) + + self.check_tokenize('x = 3.14e159', """\ + NAME 'x' (1, 0) (1, 1) + EQUAL '=' (1, 2) (1, 3) + NUMBER '3.14e159' (1, 4) (1, 12) + """) + + def test_string(self): + + self.check_tokenize('x = \'\'; y = ""', """\ + NAME 'x' (1, 0) (1, 1) + EQUAL '=' (1, 2) (1, 3) + STRING "''" (1, 4) (1, 6) + SEMI ';' (1, 6) (1, 7) + NAME 'y' (1, 8) (1, 9) + EQUAL '=' (1, 10) (1, 11) + STRING '""' (1, 12) (1, 14) + """) + + self.check_tokenize('x = 
\'"\'; y = "\'"', """\ + NAME 'x' (1, 0) (1, 1) + EQUAL '=' (1, 2) (1, 3) + STRING '\\'"\\'' (1, 4) (1, 7) + SEMI ';' (1, 7) (1, 8) + NAME 'y' (1, 9) (1, 10) + EQUAL '=' (1, 11) (1, 12) + STRING '"\\'"' (1, 13) (1, 16) + """) + + self.check_tokenize('x = "doesn\'t "shrink", does it"', """\ + NAME 'x' (1, 0) (1, 1) + EQUAL '=' (1, 2) (1, 3) + STRING '"doesn\\'t "' (1, 4) (1, 14) + NAME 'shrink' (1, 14) (1, 20) + STRING '", does it"' (1, 20) (1, 31) + """) + + self.check_tokenize("x = 'abc' + 'ABC'", """\ + NAME 'x' (1, 0) (1, 1) + EQUAL '=' (1, 2) (1, 3) + STRING "'abc'" (1, 4) (1, 9) + PLUS '+' (1, 10) (1, 11) + STRING "'ABC'" (1, 12) (1, 17) + """) + + self.check_tokenize('y = "ABC" + "ABC"', """\ + NAME 'y' (1, 0) (1, 1) + EQUAL '=' (1, 2) (1, 3) + STRING '"ABC"' (1, 4) (1, 9) + PLUS '+' (1, 10) (1, 11) + STRING '"ABC"' (1, 12) (1, 17) + """) + + self.check_tokenize("x = r'abc' + r'ABC' + R'ABC' + R'ABC'", """\ + NAME 'x' (1, 0) (1, 1) + EQUAL '=' (1, 2) (1, 3) + STRING "r'abc'" (1, 4) (1, 10) + PLUS '+' (1, 11) (1, 12) + STRING "r'ABC'" (1, 13) (1, 19) + PLUS '+' (1, 20) (1, 21) + STRING "R'ABC'" (1, 22) (1, 28) + PLUS '+' (1, 29) (1, 30) + STRING "R'ABC'" (1, 31) (1, 37) + """) + + self.check_tokenize('y = r"abc" + r"ABC" + R"ABC" + R"ABC"', """\ + NAME 'y' (1, 0) (1, 1) + EQUAL '=' (1, 2) (1, 3) + STRING 'r"abc"' (1, 4) (1, 10) + PLUS '+' (1, 11) (1, 12) + STRING 'r"ABC"' (1, 13) (1, 19) + PLUS '+' (1, 20) (1, 21) + STRING 'R"ABC"' (1, 22) (1, 28) + PLUS '+' (1, 29) (1, 30) + STRING 'R"ABC"' (1, 31) (1, 37) + """) + + self.check_tokenize("u'abc' + U'abc'", """\ + STRING "u'abc'" (1, 0) (1, 6) + PLUS '+' (1, 7) (1, 8) + STRING "U'abc'" (1, 9) (1, 15) + """) + + self.check_tokenize('u"abc" + U"abc"', """\ + STRING 'u"abc"' (1, 0) (1, 6) + PLUS '+' (1, 7) (1, 8) + STRING 'U"abc"' (1, 9) (1, 15) + """) + + self.check_tokenize("b'abc' + B'abc'", """\ + STRING "b'abc'" (1, 0) (1, 6) + PLUS '+' (1, 7) (1, 8) + STRING "B'abc'" (1, 9) (1, 15) + """) + + 
self.check_tokenize('b"abc" + B"abc"', """\ + STRING 'b"abc"' (1, 0) (1, 6) + PLUS '+' (1, 7) (1, 8) + STRING 'B"abc"' (1, 9) (1, 15) + """) + + self.check_tokenize("br'abc' + bR'abc' + Br'abc' + BR'abc'", """\ + STRING "br'abc'" (1, 0) (1, 7) + PLUS '+' (1, 8) (1, 9) + STRING "bR'abc'" (1, 10) (1, 17) + PLUS '+' (1, 18) (1, 19) + STRING "Br'abc'" (1, 20) (1, 27) + PLUS '+' (1, 28) (1, 29) + STRING "BR'abc'" (1, 30) (1, 37) + """) + + self.check_tokenize('br"abc" + bR"abc" + Br"abc" + BR"abc"', """\ + STRING 'br"abc"' (1, 0) (1, 7) + PLUS '+' (1, 8) (1, 9) + STRING 'bR"abc"' (1, 10) (1, 17) + PLUS '+' (1, 18) (1, 19) + STRING 'Br"abc"' (1, 20) (1, 27) + PLUS '+' (1, 28) (1, 29) + STRING 'BR"abc"' (1, 30) (1, 37) + """) + + self.check_tokenize("rb'abc' + rB'abc' + Rb'abc' + RB'abc'", """\ + STRING "rb'abc'" (1, 0) (1, 7) + PLUS '+' (1, 8) (1, 9) + STRING "rB'abc'" (1, 10) (1, 17) + PLUS '+' (1, 18) (1, 19) + STRING "Rb'abc'" (1, 20) (1, 27) + PLUS '+' (1, 28) (1, 29) + STRING "RB'abc'" (1, 30) (1, 37) + """) + + self.check_tokenize('rb"abc" + rB"abc" + Rb"abc" + RB"abc"', """\ + STRING 'rb"abc"' (1, 0) (1, 7) + PLUS '+' (1, 8) (1, 9) + STRING 'rB"abc"' (1, 10) (1, 17) + PLUS '+' (1, 18) (1, 19) + STRING 'Rb"abc"' (1, 20) (1, 27) + PLUS '+' (1, 28) (1, 29) + STRING 'RB"abc"' (1, 30) (1, 37) + """) + + self.check_tokenize('"a\\\nde\\\nfg"', """\ + STRING '"a\\\\\\nde\\\\\\nfg"\' (1, 0) (3, 3) + """) + + self.check_tokenize('u"a\\\nde"', """\ + STRING 'u"a\\\\\\nde"\' (1, 0) (2, 3) + """) + + self.check_tokenize('rb"a\\\nd"', """\ + STRING 'rb"a\\\\\\nd"\' (1, 0) (2, 2) + """) + + self.check_tokenize(r'"""a\ +b"""', """\ + STRING '\"\""a\\\\\\nb\"\""' (1, 0) (2, 4) + """) + self.check_tokenize(r'u"""a\ +b"""', """\ + STRING 'u\"\""a\\\\\\nb\"\""' (1, 0) (2, 4) + """) + self.check_tokenize(r'rb"""a\ +b\ +c"""', """\ + STRING 'rb"\""a\\\\\\nb\\\\\\nc"\""' (1, 0) (3, 4) + """) + + self.check_tokenize(r'"hola\\\r\ndfgf"', """\ + STRING \'"hola\\\\\\\\\\\\r\\\\ndfgf"\' (1, 
0) (1, 16) + """) + + self.check_tokenize('f"abc"', """\ + FSTRING_START 'f"' (1, 0) (1, 2) + FSTRING_MIDDLE 'abc' (1, 2) (1, 5) + FSTRING_END '"' (1, 5) (1, 6) + """) + + self.check_tokenize('fR"a{b}c"', """\ + FSTRING_START 'fR"' (1, 0) (1, 3) + FSTRING_MIDDLE 'a' (1, 3) (1, 4) + LBRACE '{' (1, 4) (1, 5) + NAME 'b' (1, 5) (1, 6) + RBRACE '}' (1, 6) (1, 7) + FSTRING_MIDDLE 'c' (1, 7) (1, 8) + FSTRING_END '"' (1, 8) (1, 9) + """) + + self.check_tokenize('f"""abc"""', """\ + FSTRING_START 'f\"""' (1, 0) (1, 4) + FSTRING_MIDDLE 'abc' (1, 4) (1, 7) + FSTRING_END '\"""' (1, 7) (1, 10) + """) + + self.check_tokenize(r'f"abc\ +def"', """\ + FSTRING_START \'f"\' (1, 0) (1, 2) + FSTRING_MIDDLE 'abc\\\\\\ndef' (1, 2) (2, 3) + FSTRING_END '"' (2, 3) (2, 4) + """) + + self.check_tokenize('''\ +f"{ +a}"''', """\ + FSTRING_START 'f"' (1, 0) (1, 2) + LBRACE '{' (1, 2) (1, 3) + NAME 'a' (2, 0) (2, 1) + RBRACE '}' (2, 1) (2, 2) + FSTRING_END '"' (2, 2) (2, 3) + """) + + self.check_tokenize(r'Rf"abc\ +def"', """\ + FSTRING_START 'Rf"' (1, 0) (1, 3) + FSTRING_MIDDLE 'abc\\\\\\ndef' (1, 3) (2, 3) + FSTRING_END '"' (2, 3) (2, 4) + """) + + self.check_tokenize(r'f"hola\\\r\ndfgf"', """\ + FSTRING_START \'f"\' (1, 0) (1, 2) + FSTRING_MIDDLE 'hola\\\\\\\\\\\\r\\\\ndfgf' (1, 2) (1, 16) + FSTRING_END \'"\' (1, 16) (1, 17) + """) + + self.check_tokenize("""\ +f'''__{ + x:a +}__'''""", """\ + FSTRING_START "f'''" (1, 0) (1, 4) + FSTRING_MIDDLE '__' (1, 4) (1, 6) + LBRACE '{' (1, 6) (1, 7) + NAME 'x' (2, 4) (2, 5) + COLON ':' (2, 5) (2, 6) + FSTRING_MIDDLE 'a\\n' (2, 6) (3, 0) + RBRACE '}' (3, 0) (3, 1) + FSTRING_MIDDLE '__' (3, 1) (3, 3) + FSTRING_END "'''" (3, 3) (3, 6) + """) + + self.check_tokenize("""\ +f'''__{ + x:a + b + c + d +}__'''""", """\ + FSTRING_START "f'''" (1, 0) (1, 4) + FSTRING_MIDDLE '__' (1, 4) (1, 6) + LBRACE '{' (1, 6) (1, 7) + NAME 'x' (2, 4) (2, 5) + COLON ':' (2, 5) (2, 6) + FSTRING_MIDDLE 'a\\n b\\n c\\n d\\n' (2, 6) (6, 0) + RBRACE '}' (6, 0) (6, 1) + 
FSTRING_MIDDLE '__' (6, 1) (6, 3) + FSTRING_END "'''" (6, 3) (6, 6) + """) + + def test_function(self): + + self.check_tokenize('def d22(a, b, c=2, d=2, *k): pass', """\ + NAME 'def' (1, 0) (1, 3) + NAME 'd22' (1, 4) (1, 7) + LPAR '(' (1, 7) (1, 8) + NAME 'a' (1, 8) (1, 9) + COMMA ',' (1, 9) (1, 10) + NAME 'b' (1, 11) (1, 12) + COMMA ',' (1, 12) (1, 13) + NAME 'c' (1, 14) (1, 15) + EQUAL '=' (1, 15) (1, 16) + NUMBER '2' (1, 16) (1, 17) + COMMA ',' (1, 17) (1, 18) + NAME 'd' (1, 19) (1, 20) + EQUAL '=' (1, 20) (1, 21) + NUMBER '2' (1, 21) (1, 22) + COMMA ',' (1, 22) (1, 23) + STAR '*' (1, 24) (1, 25) + NAME 'k' (1, 25) (1, 26) + RPAR ')' (1, 26) (1, 27) + COLON ':' (1, 27) (1, 28) + NAME 'pass' (1, 29) (1, 33) + """) + + self.check_tokenize('def d01v_(a=1, *k, **w): pass', """\ + NAME 'def' (1, 0) (1, 3) + NAME 'd01v_' (1, 4) (1, 9) + LPAR '(' (1, 9) (1, 10) + NAME 'a' (1, 10) (1, 11) + EQUAL '=' (1, 11) (1, 12) + NUMBER '1' (1, 12) (1, 13) + COMMA ',' (1, 13) (1, 14) + STAR '*' (1, 15) (1, 16) + NAME 'k' (1, 16) (1, 17) + COMMA ',' (1, 17) (1, 18) + DOUBLESTAR '**' (1, 19) (1, 21) + NAME 'w' (1, 21) (1, 22) + RPAR ')' (1, 22) (1, 23) + COLON ':' (1, 23) (1, 24) + NAME 'pass' (1, 25) (1, 29) + """) + + self.check_tokenize('def d23(a: str, b: int=3) -> int: pass', """\ + NAME 'def' (1, 0) (1, 3) + NAME 'd23' (1, 4) (1, 7) + LPAR '(' (1, 7) (1, 8) + NAME 'a' (1, 8) (1, 9) + COLON ':' (1, 9) (1, 10) + NAME 'str' (1, 11) (1, 14) + COMMA ',' (1, 14) (1, 15) + NAME 'b' (1, 16) (1, 17) + COLON ':' (1, 17) (1, 18) + NAME 'int' (1, 19) (1, 22) + EQUAL '=' (1, 22) (1, 23) + NUMBER '3' (1, 23) (1, 24) + RPAR ')' (1, 24) (1, 25) + RARROW '->' (1, 26) (1, 28) + NAME 'int' (1, 29) (1, 32) + COLON ':' (1, 32) (1, 33) + NAME 'pass' (1, 34) (1, 38) + """) + + def test_comparison(self): + + self.check_tokenize("if 1 < 1 > 1 == 1 >= 5 <= 0x15 <= 0x12 != " + "1 and 5 in 1 not in 1 is 1 or 5 is not 1: pass", """\ + NAME 'if' (1, 0) (1, 2) + NUMBER '1' (1, 3) (1, 4) + LESS '<' (1, 5) (1, 
6) + NUMBER '1' (1, 7) (1, 8) + GREATER '>' (1, 9) (1, 10) + NUMBER '1' (1, 11) (1, 12) + EQEQUAL '==' (1, 13) (1, 15) + NUMBER '1' (1, 16) (1, 17) + GREATEREQUAL '>=' (1, 18) (1, 20) + NUMBER '5' (1, 21) (1, 22) + LESSEQUAL '<=' (1, 23) (1, 25) + NUMBER '0x15' (1, 26) (1, 30) + LESSEQUAL '<=' (1, 31) (1, 33) + NUMBER '0x12' (1, 34) (1, 38) + NOTEQUAL '!=' (1, 39) (1, 41) + NUMBER '1' (1, 42) (1, 43) + NAME 'and' (1, 44) (1, 47) + NUMBER '5' (1, 48) (1, 49) + NAME 'in' (1, 50) (1, 52) + NUMBER '1' (1, 53) (1, 54) + NAME 'not' (1, 55) (1, 58) + NAME 'in' (1, 59) (1, 61) + NUMBER '1' (1, 62) (1, 63) + NAME 'is' (1, 64) (1, 66) + NUMBER '1' (1, 67) (1, 68) + NAME 'or' (1, 69) (1, 71) + NUMBER '5' (1, 72) (1, 73) + NAME 'is' (1, 74) (1, 76) + NAME 'not' (1, 77) (1, 80) + NUMBER '1' (1, 81) (1, 82) + COLON ':' (1, 82) (1, 83) + NAME 'pass' (1, 84) (1, 88) + """) + + def test_additive(self): + + self.check_tokenize('x = 1 - y + 15 - 1 + 0x124 + z + a[5]', """\ + NAME 'x' (1, 0) (1, 1) + EQUAL '=' (1, 2) (1, 3) + NUMBER '1' (1, 4) (1, 5) + MINUS '-' (1, 6) (1, 7) + NAME 'y' (1, 8) (1, 9) + PLUS '+' (1, 10) (1, 11) + NUMBER '15' (1, 12) (1, 14) + MINUS '-' (1, 15) (1, 16) + NUMBER '1' (1, 17) (1, 18) + PLUS '+' (1, 19) (1, 20) + NUMBER '0x124' (1, 21) (1, 26) + PLUS '+' (1, 27) (1, 28) + NAME 'z' (1, 29) (1, 30) + PLUS '+' (1, 31) (1, 32) + NAME 'a' (1, 33) (1, 34) + LSQB '[' (1, 34) (1, 35) + NUMBER '5' (1, 35) (1, 36) + RSQB ']' (1, 36) (1, 37) + """) + + def test_multiplicative(self): + + self.check_tokenize('x = 1//1*1/5*12%0x12@42', """\ + NAME 'x' (1, 0) (1, 1) + EQUAL '=' (1, 2) (1, 3) + NUMBER '1' (1, 4) (1, 5) + DOUBLESLASH '//' (1, 5) (1, 7) + NUMBER '1' (1, 7) (1, 8) + STAR '*' (1, 8) (1, 9) + NUMBER '1' (1, 9) (1, 10) + SLASH '/' (1, 10) (1, 11) + NUMBER '5' (1, 11) (1, 12) + STAR '*' (1, 12) (1, 13) + NUMBER '12' (1, 13) (1, 15) + PERCENT '%' (1, 15) (1, 16) + NUMBER '0x12' (1, 16) (1, 20) + AT '@' (1, 20) (1, 21) + NUMBER '42' (1, 21) (1, 23) + """) + + def 
test_unary(self): + + self.check_tokenize('~1 ^ 1 & 1 |1 ^ -1', """\ + TILDE '~' (1, 0) (1, 1) + NUMBER '1' (1, 1) (1, 2) + CIRCUMFLEX '^' (1, 3) (1, 4) + NUMBER '1' (1, 5) (1, 6) + AMPER '&' (1, 7) (1, 8) + NUMBER '1' (1, 9) (1, 10) + VBAR '|' (1, 11) (1, 12) + NUMBER '1' (1, 12) (1, 13) + CIRCUMFLEX '^' (1, 14) (1, 15) + MINUS '-' (1, 16) (1, 17) + NUMBER '1' (1, 17) (1, 18) + """) + + self.check_tokenize('-1*1/1+1*1//1 - ---1**1', """\ + MINUS '-' (1, 0) (1, 1) + NUMBER '1' (1, 1) (1, 2) + STAR '*' (1, 2) (1, 3) + NUMBER '1' (1, 3) (1, 4) + SLASH '/' (1, 4) (1, 5) + NUMBER '1' (1, 5) (1, 6) + PLUS '+' (1, 6) (1, 7) + NUMBER '1' (1, 7) (1, 8) + STAR '*' (1, 8) (1, 9) + NUMBER '1' (1, 9) (1, 10) + DOUBLESLASH '//' (1, 10) (1, 12) + NUMBER '1' (1, 12) (1, 13) + MINUS '-' (1, 14) (1, 15) + MINUS '-' (1, 16) (1, 17) + MINUS '-' (1, 17) (1, 18) + MINUS '-' (1, 18) (1, 19) + NUMBER '1' (1, 19) (1, 20) + DOUBLESTAR '**' (1, 20) (1, 22) + NUMBER '1' (1, 22) (1, 23) + """) + + def test_selector(self): + + self.check_tokenize("import sys, time\nx = sys.modules['time'].time()", """\ + NAME 'import' (1, 0) (1, 6) + NAME 'sys' (1, 7) (1, 10) + COMMA ',' (1, 10) (1, 11) + NAME 'time' (1, 12) (1, 16) + NEWLINE '' (1, 16) (1, 16) + NAME 'x' (2, 0) (2, 1) + EQUAL '=' (2, 2) (2, 3) + NAME 'sys' (2, 4) (2, 7) + DOT '.' (2, 7) (2, 8) + NAME 'modules' (2, 8) (2, 15) + LSQB '[' (2, 15) (2, 16) + STRING "'time'" (2, 16) (2, 22) + RSQB ']' (2, 22) (2, 23) + DOT '.' 
(2, 23) (2, 24) + NAME 'time' (2, 24) (2, 28) + LPAR '(' (2, 28) (2, 29) + RPAR ')' (2, 29) (2, 30) + """) + + def test_method(self): + + self.check_tokenize('@staticmethod\ndef foo(x,y): pass', """\ + AT '@' (1, 0) (1, 1) + NAME 'staticmethod' (1, 1) (1, 13) + NEWLINE '' (1, 13) (1, 13) + NAME 'def' (2, 0) (2, 3) + NAME 'foo' (2, 4) (2, 7) + LPAR '(' (2, 7) (2, 8) + NAME 'x' (2, 8) (2, 9) + COMMA ',' (2, 9) (2, 10) + NAME 'y' (2, 10) (2, 11) + RPAR ')' (2, 11) (2, 12) + COLON ':' (2, 12) (2, 13) + NAME 'pass' (2, 14) (2, 18) + """) + + def test_tabs(self): + + self.check_tokenize('@staticmethod\ndef foo(x,y): pass', """\ + AT '@' (1, 0) (1, 1) + NAME 'staticmethod' (1, 1) (1, 13) + NEWLINE '' (1, 13) (1, 13) + NAME 'def' (2, 0) (2, 3) + NAME 'foo' (2, 4) (2, 7) + LPAR '(' (2, 7) (2, 8) + NAME 'x' (2, 8) (2, 9) + COMMA ',' (2, 9) (2, 10) + NAME 'y' (2, 10) (2, 11) + RPAR ')' (2, 11) (2, 12) + COLON ':' (2, 12) (2, 13) + NAME 'pass' (2, 14) (2, 18) + """) + + @unittest.expectedFailure # TODO: RUSTPYTHON + def test_async(self): + + self.check_tokenize('async = 1', """\ + NAME 'async' (1, 0) (1, 5) + EQUAL '=' (1, 6) (1, 7) + NUMBER '1' (1, 8) (1, 9) + """) + + self.check_tokenize('a = (async = 1)', """\ + NAME 'a' (1, 0) (1, 1) + EQUAL '=' (1, 2) (1, 3) + LPAR '(' (1, 4) (1, 5) + NAME 'async' (1, 5) (1, 10) + EQUAL '=' (1, 11) (1, 12) + NUMBER '1' (1, 13) (1, 14) + RPAR ')' (1, 14) (1, 15) + """) + + self.check_tokenize('async()', """\ + NAME 'async' (1, 0) (1, 5) + LPAR '(' (1, 5) (1, 6) + RPAR ')' (1, 6) (1, 7) + """) + + self.check_tokenize('class async(Bar):pass', """\ + NAME 'class' (1, 0) (1, 5) + NAME 'async' (1, 6) (1, 11) + LPAR '(' (1, 11) (1, 12) + NAME 'Bar' (1, 12) (1, 15) + RPAR ')' (1, 15) (1, 16) + COLON ':' (1, 16) (1, 17) + NAME 'pass' (1, 17) (1, 21) + """) + + self.check_tokenize('class async:pass', """\ + NAME 'class' (1, 0) (1, 5) + NAME 'async' (1, 6) (1, 11) + COLON ':' (1, 11) (1, 12) + NAME 'pass' (1, 12) (1, 16) + """) + + 
self.check_tokenize('await = 1', """\ + NAME 'await' (1, 0) (1, 5) + EQUAL '=' (1, 6) (1, 7) + NUMBER '1' (1, 8) (1, 9) + """) + + self.check_tokenize('foo.async', """\ + NAME 'foo' (1, 0) (1, 3) + DOT '.' (1, 3) (1, 4) + NAME 'async' (1, 4) (1, 9) + """) + + self.check_tokenize('async for a in b: pass', """\ + NAME 'async' (1, 0) (1, 5) + NAME 'for' (1, 6) (1, 9) + NAME 'a' (1, 10) (1, 11) + NAME 'in' (1, 12) (1, 14) + NAME 'b' (1, 15) (1, 16) + COLON ':' (1, 16) (1, 17) + NAME 'pass' (1, 18) (1, 22) + """) + + self.check_tokenize('async with a as b: pass', """\ + NAME 'async' (1, 0) (1, 5) + NAME 'with' (1, 6) (1, 10) + NAME 'a' (1, 11) (1, 12) + NAME 'as' (1, 13) (1, 15) + NAME 'b' (1, 16) (1, 17) + COLON ':' (1, 17) (1, 18) + NAME 'pass' (1, 19) (1, 23) + """) + + self.check_tokenize('async.foo', """\ + NAME 'async' (1, 0) (1, 5) + DOT '.' (1, 5) (1, 6) + NAME 'foo' (1, 6) (1, 9) + """) + + self.check_tokenize('async', """\ + NAME 'async' (1, 0) (1, 5) + """) + + self.check_tokenize('async\n#comment\nawait', """\ + NAME 'async' (1, 0) (1, 5) + NEWLINE '' (1, 5) (1, 5) + NAME 'await' (3, 0) (3, 5) + """) + + self.check_tokenize('async\n...\nawait', """\ + NAME 'async' (1, 0) (1, 5) + NEWLINE '' (1, 5) (1, 5) + ELLIPSIS '...' (2, 0) (2, 3) + NEWLINE '' (2, 3) (2, 3) + NAME 'await' (3, 0) (3, 5) + """) + + self.check_tokenize('async\nawait', """\ + NAME 'async' (1, 0) (1, 5) + NEWLINE '' (1, 5) (1, 5) + NAME 'await' (2, 0) (2, 5) + """) + + self.check_tokenize('foo.async + 1', """\ + NAME 'foo' (1, 0) (1, 3) + DOT '.' 
(1, 3) (1, 4) + NAME 'async' (1, 4) (1, 9) + PLUS '+' (1, 10) (1, 11) + NUMBER '1' (1, 12) (1, 13) + """) + + self.check_tokenize('async def foo(): pass', """\ + NAME 'async' (1, 0) (1, 5) + NAME 'def' (1, 6) (1, 9) + NAME 'foo' (1, 10) (1, 13) + LPAR '(' (1, 13) (1, 14) + RPAR ')' (1, 14) (1, 15) + COLON ':' (1, 15) (1, 16) + NAME 'pass' (1, 17) (1, 21) + """) + + self.check_tokenize('''\ +async def foo(): + def foo(await): + await = 1 + if 1: + await +async += 1 +''', """\ + NAME 'async' (1, 0) (1, 5) + NAME 'def' (1, 6) (1, 9) + NAME 'foo' (1, 10) (1, 13) + LPAR '(' (1, 13) (1, 14) + RPAR ')' (1, 14) (1, 15) + COLON ':' (1, 15) (1, 16) + NEWLINE '' (1, 16) (1, 16) + INDENT '' (2, -1) (2, -1) + NAME 'def' (2, 2) (2, 5) + NAME 'foo' (2, 6) (2, 9) + LPAR '(' (2, 9) (2, 10) + NAME 'await' (2, 10) (2, 15) + RPAR ')' (2, 15) (2, 16) + COLON ':' (2, 16) (2, 17) + NEWLINE '' (2, 17) (2, 17) + INDENT '' (3, -1) (3, -1) + NAME 'await' (3, 4) (3, 9) + EQUAL '=' (3, 10) (3, 11) + NUMBER '1' (3, 12) (3, 13) + NEWLINE '' (3, 13) (3, 13) + DEDENT '' (4, -1) (4, -1) + NAME 'if' (4, 2) (4, 4) + NUMBER '1' (4, 5) (4, 6) + COLON ':' (4, 6) (4, 7) + NEWLINE '' (4, 7) (4, 7) + INDENT '' (5, -1) (5, -1) + NAME 'await' (5, 4) (5, 9) + NEWLINE '' (5, 9) (5, 9) + DEDENT '' (6, -1) (6, -1) + DEDENT '' (6, -1) (6, -1) + NAME 'async' (6, 0) (6, 5) + PLUSEQUAL '+=' (6, 6) (6, 8) + NUMBER '1' (6, 9) (6, 10) + NEWLINE '' (6, 10) (6, 10) + """) + + self.check_tokenize('async def foo():\n async for i in 1: pass', """\ + NAME 'async' (1, 0) (1, 5) + NAME 'def' (1, 6) (1, 9) + NAME 'foo' (1, 10) (1, 13) + LPAR '(' (1, 13) (1, 14) + RPAR ')' (1, 14) (1, 15) + COLON ':' (1, 15) (1, 16) + NEWLINE '' (1, 16) (1, 16) + INDENT '' (2, -1) (2, -1) + NAME 'async' (2, 2) (2, 7) + NAME 'for' (2, 8) (2, 11) + NAME 'i' (2, 12) (2, 13) + NAME 'in' (2, 14) (2, 16) + NUMBER '1' (2, 17) (2, 18) + COLON ':' (2, 18) (2, 19) + NAME 'pass' (2, 20) (2, 24) + DEDENT '' (2, -1) (2, -1) + """) + + 
self.check_tokenize('async def foo(async): await', """\ + NAME 'async' (1, 0) (1, 5) + NAME 'def' (1, 6) (1, 9) + NAME 'foo' (1, 10) (1, 13) + LPAR '(' (1, 13) (1, 14) + NAME 'async' (1, 14) (1, 19) + RPAR ')' (1, 19) (1, 20) + COLON ':' (1, 20) (1, 21) + NAME 'await' (1, 22) (1, 27) + """) + + self.check_tokenize('''\ +def f(): + + def baz(): pass + async def bar(): pass + + await = 2''', """\ + NAME 'def' (1, 0) (1, 3) + NAME 'f' (1, 4) (1, 5) + LPAR '(' (1, 5) (1, 6) + RPAR ')' (1, 6) (1, 7) + COLON ':' (1, 7) (1, 8) + NEWLINE '' (1, 8) (1, 8) + INDENT '' (3, -1) (3, -1) + NAME 'def' (3, 2) (3, 5) + NAME 'baz' (3, 6) (3, 9) + LPAR '(' (3, 9) (3, 10) + RPAR ')' (3, 10) (3, 11) + COLON ':' (3, 11) (3, 12) + NAME 'pass' (3, 13) (3, 17) + NEWLINE '' (3, 17) (3, 17) + NAME 'async' (4, 2) (4, 7) + NAME 'def' (4, 8) (4, 11) + NAME 'bar' (4, 12) (4, 15) + LPAR '(' (4, 15) (4, 16) + RPAR ')' (4, 16) (4, 17) + COLON ':' (4, 17) (4, 18) + NAME 'pass' (4, 19) (4, 23) + NEWLINE '' (4, 23) (4, 23) + NAME 'await' (6, 2) (6, 7) + EQUAL '=' (6, 8) (6, 9) + NUMBER '2' (6, 10) (6, 11) + DEDENT '' (6, -1) (6, -1) + """) + + self.check_tokenize('''\ +async def f(): + + def baz(): pass + async def bar(): pass + + await = 2''', """\ + NAME 'async' (1, 0) (1, 5) + NAME 'def' (1, 6) (1, 9) + NAME 'f' (1, 10) (1, 11) + LPAR '(' (1, 11) (1, 12) + RPAR ')' (1, 12) (1, 13) + COLON ':' (1, 13) (1, 14) + NEWLINE '' (1, 14) (1, 14) + INDENT '' (3, -1) (3, -1) + NAME 'def' (3, 2) (3, 5) + NAME 'baz' (3, 6) (3, 9) + LPAR '(' (3, 9) (3, 10) + RPAR ')' (3, 10) (3, 11) + COLON ':' (3, 11) (3, 12) + NAME 'pass' (3, 13) (3, 17) + NEWLINE '' (3, 17) (3, 17) + NAME 'async' (4, 2) (4, 7) + NAME 'def' (4, 8) (4, 11) + NAME 'bar' (4, 12) (4, 15) + LPAR '(' (4, 15) (4, 16) + RPAR ')' (4, 16) (4, 17) + COLON ':' (4, 17) (4, 18) + NAME 'pass' (4, 19) (4, 23) + NEWLINE '' (4, 23) (4, 23) + NAME 'await' (6, 2) (6, 7) + EQUAL '=' (6, 8) (6, 9) + NUMBER '2' (6, 10) (6, 11) + DEDENT '' (6, -1) (6, -1) + """) + + 
def test_unicode(self): + + self.check_tokenize("Örter = u'places'\ngrün = U'green'", """\ + NAME 'Örter' (1, 0) (1, 5) + EQUAL '=' (1, 6) (1, 7) + STRING "u'places'" (1, 8) (1, 17) + NEWLINE '' (1, 17) (1, 17) + NAME 'grün' (2, 0) (2, 4) + EQUAL '=' (2, 5) (2, 6) + STRING "U'green'" (2, 7) (2, 15) + """) + + @unittest.expectedFailure # TODO: RUSTPYTHON + def test_invalid_syntax(self): + def get_tokens(string): + the_string = StringIO(string) + return list(tokenize._generate_tokens_from_c_tokenizer(the_string.readline)) + + for case in [ + "(1+2]", + "(1+2}", + "{1+2]", + "1_", + "1.2_", + "1e2_", + "1e+", + + "\xa0", + "€", + "0b12", + "0b1_2", + "0b2", + "0b1_", + "0b", + "0o18", + "0o1_8", + "0o8", + "0o1_", + "0o", + "0x1_", + "0x", + "1_", + "012", + "1.2_", + "1e2_", + "1e+", + "'sdfsdf", + "'''sdfsdf''", + "("*1000+"a"+")"*1000, + "]", + """\ + f'__{ + x:d + }__'""", + " a\n\x00", + ]: + with self.subTest(case=case): + self.assertRaises(tokenize.TokenError, get_tokens, case) + + @unittest.expectedFailure # TODO: RUSTPYTHON; AssertionError: IndentationError not raised by + @support.skip_wasi_stack_overflow() + def test_max_indent(self): + MAXINDENT = 100 + + def generate_source(indents): + source = ''.join((' ' * x) + 'if True:\n' for x in range(indents)) + source += ' ' * indents + 'pass\n' + return source + + valid = generate_source(MAXINDENT - 1) + the_input = StringIO(valid) + tokens = list(tokenize._generate_tokens_from_c_tokenizer(the_input.readline)) + self.assertEqual(tokens[-2].type, tokenize.DEDENT) + self.assertEqual(tokens[-1].type, tokenize.ENDMARKER) + compile(valid, "", "exec") + + invalid = generate_source(MAXINDENT) + the_input = StringIO(invalid) + self.assertRaises(IndentationError, lambda: list(tokenize._generate_tokens_from_c_tokenizer(the_input.readline))) + self.assertRaises( + IndentationError, compile, invalid, "", "exec" + ) + + @unittest.expectedFailure # TODO: RUSTPYTHON; (0, '')] + def test_continuation_lines_indentation(self): + 
def get_tokens(string): + the_string = StringIO(string) + return [(kind, string) for (kind, string, *_) + in tokenize._generate_tokens_from_c_tokenizer(the_string.readline)] + + code = dedent(""" + def fib(n): + \\ + '''Print a Fibonacci series up to n.''' + \\ + a, b = 0, 1 + """) + + self.check_tokenize(code, """\ + NAME 'def' (2, 0) (2, 3) + NAME 'fib' (2, 4) (2, 7) + LPAR '(' (2, 7) (2, 8) + NAME 'n' (2, 8) (2, 9) + RPAR ')' (2, 9) (2, 10) + COLON ':' (2, 10) (2, 11) + NEWLINE '' (2, 11) (2, 11) + INDENT '' (4, -1) (4, -1) + STRING "'''Print a Fibonacci series up to n.'''" (4, 0) (4, 39) + NEWLINE '' (4, 39) (4, 39) + NAME 'a' (6, 0) (6, 1) + COMMA ',' (6, 1) (6, 2) + NAME 'b' (6, 3) (6, 4) + EQUAL '=' (6, 5) (6, 6) + NUMBER '0' (6, 7) (6, 8) + COMMA ',' (6, 8) (6, 9) + NUMBER '1' (6, 10) (6, 11) + NEWLINE '' (6, 11) (6, 11) + DEDENT '' (6, -1) (6, -1) + """) + + code_no_cont = dedent(""" + def fib(n): + '''Print a Fibonacci series up to n.''' + a, b = 0, 1 + """) + + self.assertEqual(get_tokens(code), get_tokens(code_no_cont)) + + code = dedent(""" + pass + \\ + + pass + """) + + self.check_tokenize(code, """\ + NAME 'pass' (2, 0) (2, 4) + NEWLINE '' (2, 4) (2, 4) + NAME 'pass' (5, 0) (5, 4) + NEWLINE '' (5, 4) (5, 4) + """) + + code_no_cont = dedent(""" + pass + pass + """) + + self.assertEqual(get_tokens(code), get_tokens(code_no_cont)) + + code = dedent(""" + if x: + y = 1 + \\ + \\ + \\ + \\ + foo = 1 + """) + + self.check_tokenize(code, """\ + NAME 'if' (2, 0) (2, 2) + NAME 'x' (2, 3) (2, 4) + COLON ':' (2, 4) (2, 5) + NEWLINE '' (2, 5) (2, 5) + INDENT '' (3, -1) (3, -1) + NAME 'y' (3, 4) (3, 5) + EQUAL '=' (3, 6) (3, 7) + NUMBER '1' (3, 8) (3, 9) + NEWLINE '' (3, 9) (3, 9) + NAME 'foo' (8, 4) (8, 7) + EQUAL '=' (8, 8) (8, 9) + NUMBER '1' (8, 10) (8, 11) + NEWLINE '' (8, 11) (8, 11) + DEDENT '' (8, -1) (8, -1) + """) + + code_no_cont = dedent(""" + if x: + y = 1 + foo = 1 + """) + + self.assertEqual(get_tokens(code), get_tokens(code_no_cont)) + + +class 
CTokenizerBufferTests(unittest.TestCase): + def test_newline_at_the_end_of_buffer(self): + # See issue 99581: Make sure that if we need to add a new line at the + # end of the buffer, we have enough space in the buffer, specially when + # the current line is as long as the buffer space available. + test_script = f"""\ + #coding: latin-1 + #{"a"*10000} + #{"a"*10002}""" + with os_helper.temp_dir() as temp_dir: + file_name = make_script(temp_dir, 'foo', test_script) + run_test_script(file_name) + + +class CommandLineTest(unittest.TestCase): + def setUp(self): + self.filename = tempfile.mktemp() + self.addCleanup(os_helper.unlink, self.filename) + + @staticmethod + def text_normalize(string): + """Dedent *string* and strip it from its surrounding whitespaces. + + This method is used by the other utility functions so that any + string to write or to match against can be freely indented. + """ + return re.sub(r'\s+', ' ', string).strip() + + def set_source(self, content): + with open(self.filename, 'w') as fp: + fp.write(content) + + def invoke_tokenize(self, *flags): + output = StringIO() + with contextlib.redirect_stdout(output): + tokenize._main(args=[*flags, self.filename]) + return self.text_normalize(output.getvalue()) + + def check_output(self, source, expect, *flags): + with self.subTest(source=source, flags=flags): + self.set_source(source) + res = self.invoke_tokenize(*flags) + expect = self.text_normalize(expect) + self.assertListEqual(res.splitlines(), expect.splitlines()) + + def test_invocation(self): + # test various combinations of parameters + base_flags = ('-e', '--exact') + + self.set_source(''' + def f(): + print(x) + return None + ''') + + for flag in base_flags: + with self.subTest(args=flag): + _ = self.invoke_tokenize(flag) + + with self.assertRaises(SystemExit): + # suppress argparse error message + with contextlib.redirect_stderr(StringIO()): + _ = self.invoke_tokenize('--unknown') + + def test_without_flag(self): + # test 'python -m tokenize 
source.py' + source = 'a = 1' + expect = ''' + 0,0-0,0: ENCODING 'utf-8' + 1,0-1,1: NAME 'a' + 1,2-1,3: OP '=' + 1,4-1,5: NUMBER '1' + 1,5-1,6: NEWLINE '' + 2,0-2,0: ENDMARKER '' + ''' + self.check_output(source, expect) + + def test_exact_flag(self): + # test 'python -m tokenize -e/--exact source.py' + source = 'a = 1' + expect = ''' + 0,0-0,0: ENCODING 'utf-8' + 1,0-1,1: NAME 'a' + 1,2-1,3: EQUAL '=' + 1,4-1,5: NUMBER '1' + 1,5-1,6: NEWLINE '' + 2,0-2,0: ENDMARKER '' + ''' + for flag in ['-e', '--exact']: + self.check_output(source, expect, flag) + + +class StringPrefixTest(unittest.TestCase): + @staticmethod + def determine_valid_prefixes(): + # Try all lengths until we find a length that has zero valid + # prefixes. This will miss the case where for example there + # are no valid 3 character prefixes, but there are valid 4 + # character prefixes. That seems unlikely. + + single_char_valid_prefixes = set() + + # Find all of the single character string prefixes. Just get + # the lowercase version, we'll deal with combinations of upper + # and lower case later. I'm using this logic just in case + # some uppercase-only prefix is added. + for letter in itertools.chain(string.ascii_lowercase, string.ascii_uppercase): + try: + eval(f'{letter}""') + single_char_valid_prefixes.add(letter.lower()) + except SyntaxError: + pass + + # This logic assumes that all combinations of valid prefixes only use + # the characters that are valid single character prefixes. That seems + # like a valid assumption, but if it ever changes this will need + # adjusting. 
+ valid_prefixes = set() + for length in itertools.count(): + num_at_this_length = 0 + for prefix in ( + "".join(l) + for l in itertools.combinations(single_char_valid_prefixes, length) + ): + for t in itertools.permutations(prefix): + for u in itertools.product(*[(c, c.upper()) for c in t]): + p = "".join(u) + if p == "not": + # 'not' can never be a string prefix, + # because it's a valid expression: not "" + continue + try: + eval(f'{p}""') + + # No syntax error, so p is a valid string + # prefix. + + valid_prefixes.add(p) + num_at_this_length += 1 + except SyntaxError: + pass + if num_at_this_length == 0: + return valid_prefixes + + + def test_prefixes(self): + # Get the list of defined string prefixes. I don't see an + # obvious documented way of doing this, but probably the best + # thing is to split apart tokenize.StringPrefix. + + # Make sure StringPrefix begins and ends in parens. We're + # assuming it's of the form "(a|b|ab)", if a, b, and cd are + # valid string prefixes. + self.assertEqual(tokenize.StringPrefix[0], '(') + self.assertEqual(tokenize.StringPrefix[-1], ')') + + # Then split apart everything else by '|'. + defined_prefixes = set(tokenize.StringPrefix[1:-1].split('|')) + + # Now compute the actual allowed string prefixes and compare + # to what is defined in the tokenize module. 
+ self.assertEqual(defined_prefixes, self.determine_valid_prefixes()) + + if __name__ == "__main__": unittest.main() diff --git a/Lib/test/test_types.py b/Lib/test/test_types.py index 2c03781bc72..5042b3c17b0 100644 --- a/Lib/test/test_types.py +++ b/Lib/test/test_types.py @@ -431,7 +431,6 @@ def test(i, format_spec, result): test(123456, "1=20", '11111111111111123456') test(123456, "*=20", '**************123456') - @unittest.expectedFailure # TODO: RUSTPYTHON; + 1234.57 @run_with_locale('LC_NUMERIC', 'en_US.UTF8', '') def test_float__format__locale(self): # test locale support for __format__ code 'n' @@ -441,7 +440,6 @@ def test_float__format__locale(self): self.assertEqual(locale.format_string('%g', x, grouping=True), format(x, 'n')) self.assertEqual(locale.format_string('%.10g', x, grouping=True), format(x, '.10n')) - @unittest.expectedFailure # TODO: RUSTPYTHON; + 123456789012345678901234567890 @run_with_locale('LC_NUMERIC', 'en_US.UTF8', '') def test_int__format__locale(self): # test locale support for __format__ code 'n' for integers diff --git a/Lib/test/test_weakref.py b/Lib/test/test_weakref.py index e04afbb1af5..27bbcf7048d 100644 --- a/Lib/test/test_weakref.py +++ b/Lib/test/test_weakref.py @@ -1862,7 +1862,6 @@ def test_weak_valued_delitem(self): self.assertEqual(len(d), 1) self.assertEqual(list(d.items()), [('something else', o2)]) - @unittest.expectedFailure # TODO: RUSTPYTHON def test_weak_keyed_bad_delitem(self): d = weakref.WeakKeyDictionary() o = Object('1') diff --git a/Lib/tokenize.py b/Lib/tokenize.py index d72968e4250..1f31258ce36 100644 --- a/Lib/tokenize.py +++ b/Lib/tokenize.py @@ -24,10 +24,7 @@ __credits__ = ('GvR, ESR, Tim Peters, Thomas Wouters, Fred Drake, ' 'Skip Montanaro, Raymond Hettinger, Trent Nelson, ' 'Michael Foord') -try: - from builtins import open as _builtin_open -except ImportError: - pass +from builtins import open as _builtin_open from codecs import lookup, BOM_UTF8 import collections import functools @@ -37,13 +34,14 
@@ import sys from token import * from token import EXACT_TOKEN_TYPES +import _tokenize -cookie_re = re.compile(r'^[ \t\f]*#.*?coding[:=][ \t]*([-\w.]+)', re.ASCII) +cookie_re = re.compile(br'^[ \t\f]*#.*?coding[:=][ \t]*([-\w.]+)', re.ASCII) blank_re = re.compile(br'^[ \t\f]*(?:[#\r\n]|$)', re.ASCII) import token __all__ = token.__all__ + ["tokenize", "generate_tokens", "detect_encoding", - "untokenize", "TokenInfo"] + "untokenize", "TokenInfo", "open", "TokenError"] del token class TokenInfo(collections.namedtuple('TokenInfo', 'type string start end line')): @@ -88,7 +86,7 @@ def _all_string_prefixes(): # The valid string prefixes. Only contain the lower case versions, # and don't contain any permutations (include 'fr', but not # 'rf'). The various permutations will be generated. - _valid_string_prefixes = ['b', 'r', 'u', 'f', 'br', 'fr'] + _valid_string_prefixes = ['b', 'r', 'u', 'f', 't', 'br', 'fr', 'tr'] # if we add binary f-strings, add: ['fb', 'fbr'] result = {''} for prefix in _valid_string_prefixes: @@ -134,7 +132,7 @@ def _compile(expr): group("'", r'\\\r?\n'), StringPrefix + r'"[^\n"\\]*(?:\\.[^\n"\\]*)*' + group('"', r'\\\r?\n')) -PseudoExtras = group(r'\\\r?\n|\Z', Comment, Triple) +PseudoExtras = group(r'\\\r?\n|\z', Comment, Triple) PseudoToken = Whitespace + group(PseudoExtras, Number, Funny, ContStr, Name) # For a given string prefix plus quotes, endpats maps it to a regex @@ -146,6 +144,7 @@ def _compile(expr): endpats[_prefix + '"'] = Double endpats[_prefix + "'''"] = Single3 endpats[_prefix + '"""'] = Double3 +del _prefix # A set of all of the single and triple quoted string prefixes, # including the opening quotes. 
@@ -156,13 +155,12 @@ def _compile(expr): single_quoted.add(u) for u in (t + '"""', t + "'''"): triple_quoted.add(u) +del t, u tabsize = 8 class TokenError(Exception): pass -class StopTokenizing(Exception): pass - class Untokenizer: @@ -170,6 +168,8 @@ def __init__(self): self.tokens = [] self.prev_row = 1 self.prev_col = 0 + self.prev_type = None + self.prev_line = "" self.encoding = None def add_whitespace(self, start): @@ -177,14 +177,51 @@ def add_whitespace(self, start): if row < self.prev_row or row == self.prev_row and col < self.prev_col: raise ValueError("start ({},{}) precedes previous end ({},{})" .format(row, col, self.prev_row, self.prev_col)) - row_offset = row - self.prev_row - if row_offset: - self.tokens.append("\\\n" * row_offset) - self.prev_col = 0 + self.add_backslash_continuation(start) col_offset = col - self.prev_col if col_offset: self.tokens.append(" " * col_offset) + def add_backslash_continuation(self, start): + """Add backslash continuation characters if the row has increased + without encountering a newline token. + + This also inserts the correct amount of whitespace before the backslash. 
+ """ + row = start[0] + row_offset = row - self.prev_row + if row_offset == 0: + return + + newline = '\r\n' if self.prev_line.endswith('\r\n') else '\n' + line = self.prev_line.rstrip('\\\r\n') + ws = ''.join(_itertools.takewhile(str.isspace, reversed(line))) + self.tokens.append(ws + f"\\{newline}" * row_offset) + self.prev_col = 0 + + def escape_brackets(self, token): + characters = [] + consume_until_next_bracket = False + for character in token: + if character == "}": + if consume_until_next_bracket: + consume_until_next_bracket = False + else: + characters.append(character) + if character == "{": + n_backslashes = sum( + 1 for char in _itertools.takewhile( + "\\".__eq__, + characters[-2::-1] + ) + ) + if n_backslashes % 2 == 0 or characters[-1] != "N": + characters.append(character) + else: + consume_until_next_bracket = True + characters.append(character) + return "".join(characters) + def untokenize(self, iterable): it = iter(iterable) indents = [] @@ -214,12 +251,22 @@ def untokenize(self, iterable): self.tokens.append(indent) self.prev_col = len(indent) startline = False + elif tok_type in {FSTRING_MIDDLE, TSTRING_MIDDLE}: + if '{' in token or '}' in token: + token = self.escape_brackets(token) + last_line = token.splitlines()[-1] + end_line, end_col = end + extra_chars = last_line.count("{{") + last_line.count("}}") + end = (end_line, end_col + extra_chars) + self.add_whitespace(start) self.tokens.append(token) self.prev_row, self.prev_col = end if tok_type in (NEWLINE, NL): self.prev_row += 1 self.prev_col = 0 + self.prev_type = tok_type + self.prev_line = line return "".join(self.tokens) def compat(self, token, iterable): @@ -227,6 +274,7 @@ def compat(self, token, iterable): toks_append = self.tokens.append startline = token[0] in (NEWLINE, NL) prevstring = False + in_fstring_or_tstring = 0 for tok in _itertools.chain([token], iterable): toknum, tokval = tok[:2] @@ -245,6 +293,10 @@ def compat(self, token, iterable): else: prevstring = False + if 
toknum in {FSTRING_START, TSTRING_START}: + in_fstring_or_tstring += 1 + elif toknum in {FSTRING_END, TSTRING_END}: + in_fstring_or_tstring -= 1 if toknum == INDENT: indents.append(tokval) continue @@ -256,7 +308,19 @@ def compat(self, token, iterable): elif startline and indents: toks_append(indents[-1]) startline = False + elif toknum in {FSTRING_MIDDLE, TSTRING_MIDDLE}: + tokval = self.escape_brackets(tokval) + + # Insert a space between two consecutive brackets if we are in an f-string or t-string + if tokval in {"{", "}"} and self.tokens and self.tokens[-1] == tokval and in_fstring_or_tstring: + tokval = ' ' + tokval + + # Insert a space between two consecutive f-strings + if toknum in (STRING, FSTRING_START) and self.prev_type in (STRING, FSTRING_END): + self.tokens.append(" ") + toks_append(tokval) + self.prev_type = toknum def untokenize(iterable): @@ -268,16 +332,10 @@ def untokenize(iterable): with at least two elements, a token number and token value. If only two tokens are passed, the resulting output is poor. - Round-trip invariant for full input: - Untokenized source will match input source exactly - - Round-trip invariant for limited input: - # Output bytes will tokenize back to the input - t1 = [tok[:2] for tok in tokenize(f.readline)] - newcode = untokenize(t1) - readline = BytesIO(newcode).readline - t2 = [tok[:2] for tok in tokenize(readline)] - assert t1 == t2 + The result is guaranteed to tokenize back to match the input so + that the conversion is lossless and round-trips are assured. + The guarantee applies only to the token type and token string as + the spacing between tokens (column positions) may change. """ ut = Untokenizer() out = ut.untokenize(iterable) @@ -287,7 +345,7 @@ def untokenize(iterable): def _get_normal_name(orig_enc): - """Imitates get_normal_name in tokenizer.c.""" + """Imitates get_normal_name in Parser/tokenizer/helpers.c.""" # Only care about the first 12 characters. 
enc = orig_enc[:12].lower().replace("_", "-") if enc == "utf-8" or enc.startswith("utf-8-"): @@ -327,22 +385,23 @@ def read_or_stop(): except StopIteration: return b'' - def find_cookie(line): + def check(line, encoding): + # Check if the line matches the encoding. + if 0 in line: + raise SyntaxError("source code cannot contain null bytes") try: - # Decode as UTF-8. Either the line is an encoding declaration, - # in which case it should be pure ASCII, or it must be UTF-8 - # per default encoding. - line_string = line.decode('utf-8') + line.decode(encoding) except UnicodeDecodeError: msg = "invalid or missing encoding declaration" if filename is not None: msg = '{} for {!r}'.format(msg, filename) raise SyntaxError(msg) - match = cookie_re.match(line_string) + def find_cookie(line): + match = cookie_re.match(line) if not match: return None - encoding = _get_normal_name(match.group(1)) + encoding = _get_normal_name(match.group(1).decode()) try: codec = lookup(encoding) except LookupError: @@ -375,18 +434,23 @@ def find_cookie(line): encoding = find_cookie(first) if encoding: + check(first, encoding) return encoding, [first] if not blank_re.match(first): + check(first, default) return default, [first] second = read_or_stop() if not second: + check(first, default) return default, [first] encoding = find_cookie(second) if encoding: + check(first + second, encoding) return encoding, [first, second] + check(first + second, default) return default, [first, second] @@ -405,7 +469,6 @@ def open(filename): buffer.close() raise - def tokenize(readline): """ The tokenize() generator requires one argument, readline, which @@ -426,193 +489,13 @@ def tokenize(readline): which tells you which encoding was used to decode the bytes stream. 
""" encoding, consumed = detect_encoding(readline) - empty = _itertools.repeat(b"") - rl_gen = _itertools.chain(consumed, iter(readline, b""), empty) - return _tokenize(rl_gen.__next__, encoding) - - -def _tokenize(readline, encoding): - lnum = parenlev = continued = 0 - numchars = '0123456789' - contstr, needcont = '', 0 - contline = None - indents = [0] - + rl_gen = _itertools.chain(consumed, iter(readline, b"")) if encoding is not None: if encoding == "utf-8-sig": # BOM will already have been stripped. encoding = "utf-8" yield TokenInfo(ENCODING, encoding, (0, 0), (0, 0), '') - last_line = b'' - line = b'' - while True: # loop over lines in stream - try: - # We capture the value of the line variable here because - # readline uses the empty string '' to signal end of input, - # hence `line` itself will always be overwritten at the end - # of this loop. - last_line = line - line = readline() - except StopIteration: - line = b'' - - if encoding is not None: - line = line.decode(encoding) - lnum += 1 - pos, max = 0, len(line) - - if contstr: # continued string - if not line: - raise TokenError("EOF in multi-line string", strstart) - endmatch = endprog.match(line) - if endmatch: - pos = end = endmatch.end(0) - yield TokenInfo(STRING, contstr + line[:end], - strstart, (lnum, end), contline + line) - contstr, needcont = '', 0 - contline = None - elif needcont and line[-2:] != '\\\n' and line[-3:] != '\\\r\n': - yield TokenInfo(ERRORTOKEN, contstr + line, - strstart, (lnum, len(line)), contline) - contstr = '' - contline = None - continue - else: - contstr = contstr + line - contline = contline + line - continue - - elif parenlev == 0 and not continued: # new statement - if not line: break - column = 0 - while pos < max: # measure leading whitespace - if line[pos] == ' ': - column += 1 - elif line[pos] == '\t': - column = (column//tabsize + 1)*tabsize - elif line[pos] == '\f': - column = 0 - else: - break - pos += 1 - if pos == max: - break - - if line[pos] in '#\r\n': 
# skip comments or blank lines - if line[pos] == '#': - comment_token = line[pos:].rstrip('\r\n') - yield TokenInfo(COMMENT, comment_token, - (lnum, pos), (lnum, pos + len(comment_token)), line) - pos += len(comment_token) - - yield TokenInfo(NL, line[pos:], - (lnum, pos), (lnum, len(line)), line) - continue - - if column > indents[-1]: # count indents or dedents - indents.append(column) - yield TokenInfo(INDENT, line[:pos], (lnum, 0), (lnum, pos), line) - while column < indents[-1]: - if column not in indents: - raise IndentationError( - "unindent does not match any outer indentation level", - ("", lnum, pos, line)) - indents = indents[:-1] - - yield TokenInfo(DEDENT, '', (lnum, pos), (lnum, pos), line) - - else: # continued statement - if not line: - raise TokenError("EOF in multi-line statement", (lnum, 0)) - continued = 0 - - while pos < max: - pseudomatch = _compile(PseudoToken).match(line, pos) - if pseudomatch: # scan for tokens - start, end = pseudomatch.span(1) - spos, epos, pos = (lnum, start), (lnum, end), end - if start == end: - continue - token, initial = line[start:end], line[start] - - if (initial in numchars or # ordinary number - (initial == '.' and token != '.' and token != '...')): - yield TokenInfo(NUMBER, token, spos, epos, line) - elif initial in '\r\n': - if parenlev > 0: - yield TokenInfo(NL, token, spos, epos, line) - else: - yield TokenInfo(NEWLINE, token, spos, epos, line) - - elif initial == '#': - assert not token.endswith("\n") - yield TokenInfo(COMMENT, token, spos, epos, line) - - elif token in triple_quoted: - endprog = _compile(endpats[token]) - endmatch = endprog.match(line, pos) - if endmatch: # all on one line - pos = endmatch.end(0) - token = line[start:pos] - yield TokenInfo(STRING, token, spos, (lnum, pos), line) - else: - strstart = (lnum, start) # multiple lines - contstr = line[start:] - contline = line - break - - # Check up to the first 3 chars of the token to see if - # they're in the single_quoted set. 
If so, they start - # a string. - # We're using the first 3, because we're looking for - # "rb'" (for example) at the start of the token. If - # we switch to longer prefixes, this needs to be - # adjusted. - # Note that initial == token[:1]. - # Also note that single quote checking must come after - # triple quote checking (above). - elif (initial in single_quoted or - token[:2] in single_quoted or - token[:3] in single_quoted): - if token[-1] == '\n': # continued string - strstart = (lnum, start) - # Again, using the first 3 chars of the - # token. This is looking for the matching end - # regex for the correct type of quote - # character. So it's really looking for - # endpats["'"] or endpats['"'], by trying to - # skip string prefix characters, if any. - endprog = _compile(endpats.get(initial) or - endpats.get(token[1]) or - endpats.get(token[2])) - contstr, needcont = line[start:], 1 - contline = line - break - else: # ordinary string - yield TokenInfo(STRING, token, spos, epos, line) - - elif initial.isidentifier(): # ordinary name - yield TokenInfo(NAME, token, spos, epos, line) - elif initial == '\\': # continued stmt - continued = 1 - else: - if initial in '([{': - parenlev += 1 - elif initial in ')]}': - parenlev -= 1 - yield TokenInfo(OP, token, spos, epos, line) - else: - yield TokenInfo(ERRORTOKEN, line[pos], - (lnum, pos), (lnum, pos+1), line) - pos += 1 - - # Add an implicit NEWLINE if the input doesn't end in one - if last_line and last_line[-1] not in '\r\n' and not last_line.strip().startswith("#"): - yield TokenInfo(NEWLINE, '', (lnum - 1, len(last_line)), (lnum - 1, len(last_line) + 1), '') - for indent in indents[1:]: # pop remaining indent levels - yield TokenInfo(DEDENT, '', (lnum, 0), (lnum, 0), '') - yield TokenInfo(ENDMARKER, '', (lnum, 0), (lnum, 0), '') - + yield from _generate_tokens_from_c_tokenizer(rl_gen.__next__, encoding, extra_tokens=True) def generate_tokens(readline): """Tokenize a source reading Python code as unicode strings. 
@@ -620,9 +503,9 @@ def generate_tokens(readline): This has the same API as tokenize(), except that it expects the *readline* callable to return str objects instead of bytes. """ - return _tokenize(readline, None) + return _generate_tokens_from_c_tokenizer(readline, extra_tokens=True) -def main(): +def _main(args=None): import argparse # Helper error handling routines @@ -641,13 +524,13 @@ def error(message, filename=None, location=None): sys.exit(1) # Parse the arguments and options - parser = argparse.ArgumentParser(prog='python -m tokenize') + parser = argparse.ArgumentParser(color=True) parser.add_argument(dest='filename', nargs='?', metavar='filename.py', help='the file to tokenize; defaults to stdin') parser.add_argument('-e', '--exact', dest='exact', action='store_true', help='display token names using the exact type') - args = parser.parse_args() + args = parser.parse_args(args) try: # Tokenize the input @@ -657,7 +540,9 @@ def error(message, filename=None, location=None): tokens = list(tokenize(f.readline)) else: filename = "" - tokens = _tokenize(sys.stdin.readline, None) + tokens = _generate_tokens_from_c_tokenizer( + sys.stdin.readline, extra_tokens=True) + # Output the tokenization for token in tokens: @@ -683,5 +568,31 @@ def error(message, filename=None, location=None): perror("unexpected error: %s" % err) raise +def _transform_msg(msg): + """Transform error messages from the C tokenizer into the Python tokenize + + The C tokenizer is more picky than the Python one, so we need to massage + the error messages a bit for backwards compatibility. 
+ """ + if "unterminated triple-quoted string literal" in msg: + return "EOF in multi-line string" + return msg + +def _generate_tokens_from_c_tokenizer(source, encoding=None, extra_tokens=False): + """Tokenize a source reading Python code as unicode strings using the internal C tokenizer""" + if encoding is None: + it = _tokenize.TokenizerIter(source, extra_tokens=extra_tokens) + else: + it = _tokenize.TokenizerIter(source, encoding=encoding, extra_tokens=extra_tokens) + try: + for info in it: + yield TokenInfo._make(info) + except SyntaxError as e: + if type(e) != SyntaxError: + raise e from None + msg = _transform_msg(e.msg) + raise TokenError(msg, (e.lineno, e.offset)) from None + + if __name__ == "__main__": - main() + _main() diff --git a/crates/codegen/src/compile.rs b/crates/codegen/src/compile.rs index 6f7d8c15236..94c6642aac2 100644 --- a/crates/codegen/src/compile.rs +++ b/crates/codegen/src/compile.rs @@ -29,7 +29,7 @@ use rustpython_compiler_core::{ self, AnyInstruction, Arg as OpArgMarker, BinaryOperator, BuildSliceArgCount, CodeObject, ComparisonOperator, ConstantData, ConvertValueOparg, Instruction, IntrinsicFunction1, Invert, LoadAttr, LoadSuperAttr, OpArg, OpArgType, PseudoInstruction, SpecialMethod, - UnpackExArgs, + UnpackExArgs, oparg, }, }; use rustpython_wtf8::Wtf8Buf; @@ -715,14 +715,14 @@ impl Compiler { } /// Get the index of a local variable. - fn get_local_var_index(&mut self, name: &str) -> CompileResult { + fn get_local_var_index(&mut self, name: &str) -> CompileResult { let info = self.code_stack.last_mut().unwrap(); let idx = info .metadata .varnames .get_index_of(name) .unwrap_or_else(|| info.metadata.varnames.insert_full(name.to_owned()).0); - Ok(idx.to_u32()) + Ok(idx.to_u32().into()) } /// Get the index of a global name. 
@@ -1283,7 +1283,12 @@ impl Compiler { /// if format > VALUE_WITH_FAKE_GLOBALS (2): raise NotImplementedError fn emit_format_validation(&mut self) -> CompileResult<()> { // Load format parameter (first local variable, index 0) - emit!(self, Instruction::LoadFast { var_num: 0 }); + emit!( + self, + Instruction::LoadFast { + var_num: oparg::VarNum::from_u32(0) + } + ); // Load VALUE_WITH_FAKE_GLOBALS constant (2) self.emit_load_const(ConstantData::Integer { value: 2.into() }); @@ -1562,15 +1567,19 @@ impl Compiler { fn name(&mut self, name: &str) -> bytecode::NameIdx { self._name_inner(name, |i| &mut i.metadata.names) } - fn varname(&mut self, name: &str) -> CompileResult { + + fn varname(&mut self, name: &str) -> CompileResult { // Note: __debug__ checks are now handled in symboltable phase - Ok(self._name_inner(name, |i| &mut i.metadata.varnames)) + Ok(oparg::VarNum::from_u32( + self._name_inner(name, |i| &mut i.metadata.varnames), + )) } + fn _name_inner( &mut self, name: &str, cache: impl FnOnce(&mut ir::CodeInfo) -> &mut IndexSet, - ) -> bytecode::NameIdx { + ) -> u32 { let name = self.mangle(name); let cache = cache(self.current_code_info()); cache @@ -2500,7 +2509,7 @@ impl Compiler { self.compile_expression(value)?; emit!(self, Instruction::ReturnValue); let value_code = self.exit_scope(); - self.make_closure(value_code, bytecode::MakeFunctionFlags::empty())?; + self.make_closure(value_code, bytecode::MakeFunctionFlags::new())?; // Stack: [type_params_tuple, value_closure] // Swap so unpack_sequence reverse gives correct order @@ -2513,7 +2522,7 @@ impl Compiler { let code = self.exit_scope(); self.ctx = prev_ctx; - self.make_closure(code, bytecode::MakeFunctionFlags::empty())?; + self.make_closure(code, bytecode::MakeFunctionFlags::new())?; emit!(self, Instruction::PushNull); emit!(self, Instruction::Call { argc: 0 }); @@ -2552,7 +2561,7 @@ impl Compiler { let code = self.exit_scope(); self.ctx = prev_ctx; - self.make_closure(code, 
bytecode::MakeFunctionFlags::empty())?; + self.make_closure(code, bytecode::MakeFunctionFlags::new())?; // Stack: [name, None, closure] } @@ -2716,7 +2725,7 @@ impl Compiler { self.ctx = prev_ctx; // Create closure for lazy evaluation - self.make_closure(code, bytecode::MakeFunctionFlags::empty())?; + self.make_closure(code, bytecode::MakeFunctionFlags::new())?; Ok(()) } @@ -3640,7 +3649,7 @@ impl Compiler { &mut self, parameters: &ast::Parameters, ) -> CompileResult { - let mut funcflags = bytecode::MakeFunctionFlags::empty(); + let mut funcflags = bytecode::MakeFunctionFlags::new(); // Handle positional defaults let defaults: Vec<_> = core::iter::empty() @@ -3660,7 +3669,7 @@ impl Compiler { count: defaults.len().to_u32() } ); - funcflags |= bytecode::MakeFunctionFlags::DEFAULTS; + funcflags.insert(bytecode::MakeFunctionFlag::Defaults); } // Handle keyword-only defaults @@ -3685,7 +3694,7 @@ impl Compiler { count: kw_with_defaults.len().to_u32(), } ); - funcflags |= bytecode::MakeFunctionFlags::KW_ONLY_DEFAULTS; + funcflags.insert(bytecode::MakeFunctionFlag::KwOnlyDefaults); } Ok(funcflags) @@ -3835,7 +3844,7 @@ impl Compiler { let annotate_code = self.exit_annotation_scope(saved_ctx); // Make a closure from the code object - self.make_closure(annotate_code, bytecode::MakeFunctionFlags::empty())?; + self.make_closure(annotate_code, bytecode::MakeFunctionFlags::new())?; Ok(true) } @@ -4045,7 +4054,7 @@ impl Compiler { ); // Make a closure from the code object - self.make_closure(annotate_code, bytecode::MakeFunctionFlags::empty())?; + self.make_closure(annotate_code, bytecode::MakeFunctionFlags::new())?; // Store as __annotate_func__ for classes, __annotate__ for modules let name = if parent_scope_type == CompilerScope::Class { @@ -4083,10 +4092,10 @@ impl Compiler { if is_generic { // Count args to pass to type params scope - if funcflags.contains(bytecode::MakeFunctionFlags::DEFAULTS) { + if funcflags.contains(&bytecode::MakeFunctionFlag::Defaults) { 
num_typeparam_args += 1; } - if funcflags.contains(bytecode::MakeFunctionFlags::KW_ONLY_DEFAULTS) { + if funcflags.contains(&bytecode::MakeFunctionFlag::KwOnlyDefaults) { num_typeparam_args += 1; } @@ -4111,13 +4120,13 @@ impl Compiler { // Add parameter names to varnames for the type params scope // These will be passed as arguments when the closure is called let current_info = self.current_code_info(); - if funcflags.contains(bytecode::MakeFunctionFlags::DEFAULTS) { + if funcflags.contains(&bytecode::MakeFunctionFlag::Defaults) { current_info .metadata .varnames .insert(".defaults".to_owned()); } - if funcflags.contains(bytecode::MakeFunctionFlags::KW_ONLY_DEFAULTS) { + if funcflags.contains(&bytecode::MakeFunctionFlag::KwOnlyDefaults) { current_info .metadata .varnames @@ -4129,16 +4138,16 @@ impl Compiler { // Load defaults/kwdefaults with LOAD_FAST for i in 0..num_typeparam_args { - emit!(self, Instruction::LoadFast { var_num: i as u32 }); + let var_num = oparg::VarNum::from(i as u32); + emit!(self, Instruction::LoadFast { var_num }); } } // Compile annotations as closure (PEP 649) - let annotations_flag = if self.compile_annotations_closure(name, parameters, returns)? { - bytecode::MakeFunctionFlags::ANNOTATE - } else { - bytecode::MakeFunctionFlags::empty() - }; + let mut annotations_flag = bytecode::MakeFunctionFlags::new(); + if self.compile_annotations_closure(name, parameters, returns)? { + annotations_flag.insert(bytecode::MakeFunctionFlag::Annotate); + } // Compile function body let final_funcflags = funcflags | annotations_flag; @@ -4169,7 +4178,7 @@ impl Compiler { self.ctx = saved_ctx; // Make closure for type params code - self.make_closure(type_params_code, bytecode::MakeFunctionFlags::empty())?; + self.make_closure(type_params_code, bytecode::MakeFunctionFlags::new())?; // Call the type params closure with defaults/kwdefaults as arguments. 
// Call protocol: [callable, self_or_null, arg1, ..., argN] @@ -4337,57 +4346,57 @@ impl Compiler { emit!( self, Instruction::SetFunctionAttribute { - flag: bytecode::MakeFunctionFlags::CLOSURE + flag: bytecode::MakeFunctionFlag::Closure } ); } // Set annotations if present - if flags.contains(bytecode::MakeFunctionFlags::ANNOTATIONS) { + if flags.contains(&bytecode::MakeFunctionFlag::Annotations) { emit!( self, Instruction::SetFunctionAttribute { - flag: bytecode::MakeFunctionFlags::ANNOTATIONS + flag: bytecode::MakeFunctionFlag::Annotations } ); } // Set __annotate__ closure if present (PEP 649) - if flags.contains(bytecode::MakeFunctionFlags::ANNOTATE) { + if flags.contains(&bytecode::MakeFunctionFlag::Annotate) { emit!( self, Instruction::SetFunctionAttribute { - flag: bytecode::MakeFunctionFlags::ANNOTATE + flag: bytecode::MakeFunctionFlag::Annotate } ); } // Set kwdefaults if present - if flags.contains(bytecode::MakeFunctionFlags::KW_ONLY_DEFAULTS) { + if flags.contains(&bytecode::MakeFunctionFlag::KwOnlyDefaults) { emit!( self, Instruction::SetFunctionAttribute { - flag: bytecode::MakeFunctionFlags::KW_ONLY_DEFAULTS + flag: bytecode::MakeFunctionFlag::KwOnlyDefaults } ); } // Set defaults if present - if flags.contains(bytecode::MakeFunctionFlags::DEFAULTS) { + if flags.contains(&bytecode::MakeFunctionFlag::Defaults) { emit!( self, Instruction::SetFunctionAttribute { - flag: bytecode::MakeFunctionFlags::DEFAULTS + flag: bytecode::MakeFunctionFlag::Defaults } ); } // Set type_params if present - if flags.contains(bytecode::MakeFunctionFlags::TYPE_PARAMS) { + if flags.contains(&bytecode::MakeFunctionFlag::TypeParams) { emit!( self, Instruction::SetFunctionAttribute { - flag: bytecode::MakeFunctionFlags::TYPE_PARAMS + flag: bytecode::MakeFunctionFlag::TypeParams } ); } @@ -4679,14 +4688,14 @@ impl Compiler { emit!(self, Instruction::PushNull); // Set up the class function with type params - let mut func_flags = bytecode::MakeFunctionFlags::empty(); + let mut 
func_flags = bytecode::MakeFunctionFlags::new(); emit!( self, Instruction::LoadName { namei: dot_type_params } ); - func_flags |= bytecode::MakeFunctionFlags::TYPE_PARAMS; + func_flags.insert(bytecode::MakeFunctionFlag::TypeParams); // Create class function with closure self.make_closure(class_code, func_flags)?; @@ -4809,7 +4818,7 @@ impl Compiler { self.ctx = saved_ctx; // Execute the type params function - self.make_closure(type_params_code, bytecode::MakeFunctionFlags::empty())?; + self.make_closure(type_params_code, bytecode::MakeFunctionFlags::new())?; emit!(self, Instruction::PushNull); emit!(self, Instruction::Call { argc: 0 }); } else { @@ -4818,7 +4827,7 @@ impl Compiler { emit!(self, Instruction::PushNull); // Create class function with closure - self.make_closure(class_code, bytecode::MakeFunctionFlags::empty())?; + self.make_closure(class_code, bytecode::MakeFunctionFlags::new())?; self.emit_load_const(ConstantData::Str { value: name.into() }); if let Some(arguments) = arguments { @@ -7086,12 +7095,12 @@ impl Compiler { } self.enter_function(&name, params)?; - let mut func_flags = bytecode::MakeFunctionFlags::empty(); + let mut func_flags = bytecode::MakeFunctionFlags::new(); if have_defaults { - func_flags |= bytecode::MakeFunctionFlags::DEFAULTS; + func_flags.insert(bytecode::MakeFunctionFlag::Defaults); } if have_kwdefaults { - func_flags |= bytecode::MakeFunctionFlags::KW_ONLY_DEFAULTS; + func_flags.insert(bytecode::MakeFunctionFlag::KwOnlyDefaults); } // Set qualname for lambda @@ -7775,7 +7784,7 @@ impl Compiler { self.ctx = prev_ctx; // Create comprehension function with closure - self.make_closure(code, bytecode::MakeFunctionFlags::empty())?; + self.make_closure(code, bytecode::MakeFunctionFlags::new())?; emit!(self, Instruction::PushNull); // Evaluate iterated item: @@ -8075,9 +8084,9 @@ impl Compiler { // fn block_done() - fn arg_constant(&mut self, constant: ConstantData) -> u32 { + fn arg_constant(&mut self, constant: ConstantData) -> 
oparg::ConstIdx { let info = self.current_code_info(); - info.metadata.consts.insert_full(constant).0.to_u32() + info.metadata.consts.insert_full(constant).0.to_u32().into() } fn emit_load_const(&mut self, constant: ConstantData) { @@ -9069,6 +9078,18 @@ mod tests { fn compile_exec(source: &str) -> CodeObject { let opts = CompileOpts::default(); + compile_exec_with_options(source, opts) + } + + fn compile_exec_optimized(source: &str) -> CodeObject { + let opts = CompileOpts { + optimize: 1, + ..CompileOpts::default() + }; + compile_exec_with_options(source, opts) + } + + fn compile_exec_with_options(source: &str, opts: CompileOpts) -> CodeObject { let source_file = SourceFileBuilder::new("source_path", source).finish(); let parsed = ruff_python_parser::parse( source_file.source_text(), @@ -9137,6 +9158,15 @@ x = Test() and False or False )); } + #[test] + fn test_const_bool_not_op() { + assert_dis_snapshot!(compile_exec_optimized( + "\ +x = not True +" + )); + } + #[test] fn test_nested_double_async_with() { assert_dis_snapshot!(compile_exec( diff --git a/crates/codegen/src/ir.rs b/crates/codegen/src/ir.rs index 43a2dfa5107..8d5fbdb8bde 100644 --- a/crates/codegen/src/ir.rs +++ b/crates/codegen/src/ir.rs @@ -10,7 +10,7 @@ use rustpython_compiler_core::{ bytecode::{ AnyInstruction, Arg, CodeFlags, CodeObject, CodeUnit, CodeUnits, ConstantData, ExceptionTableEntry, InstrDisplayContext, Instruction, InstructionMetadata, Label, OpArg, - PseudoInstruction, PyCodeLocationInfoKind, encode_exception_table, + PseudoInstruction, PyCodeLocationInfoKind, encode_exception_table, oparg, }, varint::{write_signed_varint, write_varint}, }; @@ -693,6 +693,33 @@ impl CodeInfo { None } } + (Instruction::LoadConst { consti }, Instruction::ToBool) => { + let consti = consti.get(curr.arg); + let constant = &self.metadata.consts[consti.as_usize()]; + if let ConstantData::Boolean { .. 
} = constant { + Some((curr_instr, OpArg::from(consti.as_u32()))) + } else { + None + } + } + (Instruction::LoadConst { consti }, Instruction::UnaryNot) => { + let constant = &self.metadata.consts[consti.get(curr.arg).as_usize()]; + match constant { + ConstantData::Boolean { value } => { + let (const_idx, _) = self + .metadata + .consts + .insert_full(ConstantData::Boolean { value: !value }); + Some(( + (Instruction::LoadConst { + consti: Arg::marker(), + }), + OpArg::new(const_idx as u32), + )) + } + _ => None, + } + } _ => None, } }; @@ -1073,15 +1100,19 @@ impl CodeInfo { impl InstrDisplayContext for CodeInfo { type Constant = ConstantData; - fn get_constant(&self, i: usize) -> &ConstantData { - &self.metadata.consts[i] + + fn get_constant(&self, consti: oparg::ConstIdx) -> &ConstantData { + &self.metadata.consts[consti.as_usize()] } + fn get_name(&self, i: usize) -> &str { self.metadata.names[i].as_ref() } - fn get_varname(&self, i: usize) -> &str { - self.metadata.varnames[i].as_ref() + + fn get_varname(&self, var_num: oparg::VarNum) -> &str { + self.metadata.varnames[var_num.as_usize()].as_ref() } + fn get_cell_name(&self, i: usize) -> &str { self.metadata .cellvars diff --git a/crates/codegen/src/snapshots/rustpython_codegen__compile__tests__const_bool_not_op.snap b/crates/codegen/src/snapshots/rustpython_codegen__compile__tests__const_bool_not_op.snap new file mode 100644 index 00000000000..f9a74c2055c --- /dev/null +++ b/crates/codegen/src/snapshots/rustpython_codegen__compile__tests__const_bool_not_op.snap @@ -0,0 +1,9 @@ +--- +source: crates/codegen/src/compile.rs +expression: "compile_exec_optimized(\"\\\nx = not True\n\")" +--- + 1 0 RESUME (0) + 1 LOAD_CONST (False) + 2 STORE_NAME (0, x) + 3 LOAD_CONST (None) + 4 RETURN_VALUE diff --git a/crates/codegen/src/symboltable.rs b/crates/codegen/src/symboltable.rs index 0d868bc0468..fdbdac2b2a7 100644 --- a/crates/codegen/src/symboltable.rs +++ b/crates/codegen/src/symboltable.rs @@ -2037,20 +2037,13 @@ impl 
SymbolTableBuilder { self.line_index_start(range), ); - // Mark non-generator comprehensions as inlined (PEP 709) - // inline_comp = entry->ste_comprehension && !entry->ste_generator && !ste->ste_can_see_class_scope - // We check is_generator and can_see_class_scope of parent - let parent_can_see_class = self - .tables - .get(self.tables.len().saturating_sub(2)) - .map(|t| t.can_see_class_scope) - .unwrap_or(false); - if !is_generator - && !parent_can_see_class - && let Some(table) = self.tables.last_mut() - { - table.comp_inlined = true; - } + // PEP 709: inlined comprehensions are not yet implemented in the + // compiler (is_inlined_comprehension_context always returns false), + // so do NOT mark comp_inlined here. Setting it would cause the + // symbol-table analyzer to merge comprehension-local symbols into + // the parent scope, while the compiler still emits a separate code + // object — leading to the merged symbols being missing from the + // comprehension's own symbol table lookup. // Register the passed argument to the generator function as the name ".0" self.register_name(".0", SymbolUsage::Parameter, range)?; diff --git a/crates/common/src/format.rs b/crates/common/src/format.rs index 40bc9e53046..930c764acf3 100644 --- a/crates/common/src/format.rs +++ b/crates/common/src/format.rs @@ -12,6 +12,19 @@ use rustpython_literal::format::Case; use crate::wtf8::{CodePoint, Wtf8, Wtf8Buf}; +/// Locale information for 'n' format specifier. +/// Contains thousands separator, decimal point, and grouping pattern +/// from the C library's `localeconv()`. +#[derive(Clone, Debug)] +pub struct LocaleInfo { + pub thousands_sep: String, + pub decimal_point: String, + /// Grouping pattern from `lconv.grouping`. + /// Each element is a group size. The last non-zero element repeats. + /// e.g. `[3, 0]` means groups of 3 repeating forever. 
+ pub grouping: Vec, +} + trait FormatParse { fn parse(text: &Wtf8) -> (Option, &Wtf8) where @@ -460,6 +473,189 @@ impl FormatSpec { } } + /// Returns true if this format spec uses the locale-aware 'n' format type. + pub fn has_locale_format(&self) -> bool { + matches!(self.format_type, Some(FormatType::Number(Case::Lower))) + } + + /// Insert locale-aware thousands separators into an integer string. + /// Follows CPython's GroupGenerator logic for variable-width grouping. + fn insert_locale_grouping(int_part: &str, locale: &LocaleInfo) -> String { + if locale.grouping.is_empty() || locale.thousands_sep.is_empty() || int_part.len() <= 1 { + return int_part.to_string(); + } + + let mut group_idx = 0; + let mut group_size = locale.grouping[0] as usize; + + if group_size == 0 { + return int_part.to_string(); + } + + // Collect groups of digits from right to left + let len = int_part.len(); + let mut groups: Vec<&str> = Vec::new(); + let mut pos = len; + + loop { + if pos <= group_size { + groups.push(&int_part[..pos]); + break; + } + + groups.push(&int_part[pos - group_size..pos]); + pos -= group_size; + + // Advance to next group size + if group_idx + 1 < locale.grouping.len() { + let next = locale.grouping[group_idx + 1] as usize; + if next != 0 { + group_size = next; + group_idx += 1; + } + // 0 means repeat previous group size forever + } + } + + // Groups were collected right-to-left, reverse to get left-to-right + groups.reverse(); + groups.join(&locale.thousands_sep) + } + + /// Apply locale-aware grouping and decimal point replacement to a formatted number. 
+ fn apply_locale_formatting(magnitude_str: String, locale: &LocaleInfo) -> String { + let mut parts = magnitude_str.splitn(2, '.'); + let int_part = parts.next().unwrap(); + let grouped = Self::insert_locale_grouping(int_part, locale); + + if let Some(frac_part) = parts.next() { + format!("{grouped}{}{frac_part}", locale.decimal_point) + } else { + grouped + } + } + + /// Format an integer with locale-aware 'n' format. + pub fn format_int_locale( + &self, + num: &BigInt, + locale: &LocaleInfo, + ) -> Result { + self.validate_format(FormatType::Decimal)?; + let magnitude = num.abs(); + + let raw_magnitude_str = match self.format_type { + Some(FormatType::Number(Case::Lower)) => self.format_int_radix(magnitude, 10), + _ => return self.format_int(num), + }?; + + let magnitude_str = Self::apply_locale_formatting(raw_magnitude_str, locale); + + let format_sign = self.sign.unwrap_or(FormatSign::Minus); + let sign_str = match num.sign() { + Sign::Minus => "-", + _ => match format_sign { + FormatSign::Plus => "+", + FormatSign::Minus => "", + FormatSign::MinusOrSpace => " ", + }, + }; + + self.format_sign_and_align(&AsciiStr::new(&magnitude_str), sign_str, FormatAlign::Right) + } + + /// Format a float with locale-aware 'n' format. 
+ pub fn format_float_locale( + &self, + num: f64, + locale: &LocaleInfo, + ) -> Result { + self.validate_format(FormatType::FixedPoint(Case::Lower))?; + let precision = self.precision.unwrap_or(6); + let magnitude = num.abs(); + + let raw_magnitude_str = match &self.format_type { + Some(FormatType::Number(case)) => { + let precision = if precision == 0 { 1 } else { precision }; + Ok(float::format_general( + precision, + magnitude, + *case, + self.alternate_form, + false, + )) + } + _ => return self.format_float(num), + }?; + + let magnitude_str = Self::apply_locale_formatting(raw_magnitude_str, locale); + + let format_sign = self.sign.unwrap_or(FormatSign::Minus); + let sign_str = if num.is_sign_negative() && !num.is_nan() { + "-" + } else { + match format_sign { + FormatSign::Plus => "+", + FormatSign::Minus => "", + FormatSign::MinusOrSpace => " ", + } + }; + + self.format_sign_and_align(&AsciiStr::new(&magnitude_str), sign_str, FormatAlign::Right) + } + + /// Format a complex number with locale-aware 'n' format. + pub fn format_complex_locale( + &self, + num: &Complex64, + locale: &LocaleInfo, + ) -> Result { + // Reuse format_complex_re_im with 'g' type to get the base formatted parts, + // then apply locale grouping. This matches CPython's format_complex_internal: + // 'n' → 'g', add_parens=0, skip_re=0. 
+ let locale_spec = FormatSpec { + format_type: Some(FormatType::GeneralFormat(Case::Lower)), + ..*self + }; + let (formatted_re, formatted_im) = locale_spec.format_complex_re_im(num)?; + + // Apply locale grouping to both parts + let grouped_re = if formatted_re.is_empty() { + formatted_re + } else { + // Split sign from magnitude, apply grouping, recombine + let (sign, mag) = if formatted_re.starts_with('-') + || formatted_re.starts_with('+') + || formatted_re.starts_with(' ') + { + formatted_re.split_at(1) + } else { + ("", formatted_re.as_str()) + }; + format!( + "{sign}{}", + Self::apply_locale_formatting(mag.to_string(), locale) + ) + }; + + // formatted_im is like "+1234j" or "-1234j" or "1234j" + // Split sign, magnitude, and 'j' suffix + let im_str = &formatted_im; + let (im_sign, im_rest) = if im_str.starts_with('+') || im_str.starts_with('-') { + im_str.split_at(1) + } else { + ("", im_str.as_str()) + }; + let im_mag = im_rest.strip_suffix('j').unwrap_or(im_rest); + let im_grouped = Self::apply_locale_formatting(im_mag.to_string(), locale); + let grouped_im = format!("{im_sign}{im_grouped}j"); + + // No parentheses for 'n' format (CPython: add_parens=0) + let magnitude_str = format!("{grouped_re}{grouped_im}"); + + self.format_sign_and_align(&AsciiStr::new(&magnitude_str), "", FormatAlign::Right) + } + pub fn format_bool(&self, input: bool) -> Result { let x = u8::from(input); match &self.format_type { diff --git a/crates/common/src/lock.rs b/crates/common/src/lock.rs index af680010821..cd7df512d83 100644 --- a/crates/common/src/lock.rs +++ b/crates/common/src/lock.rs @@ -68,32 +68,37 @@ pub type PyMappedRwLockWriteGuard<'a, T> = MappedRwLockWriteGuard<'a, RawRwLock, // can add fn const_{mutex,rw_lock}() if necessary, but we probably won't need to -/// Reset a `PyMutex` to its initial (unlocked) state after `fork()`. +/// Reset a lock to its initial (unlocked) state by zeroing its bytes. 
/// -/// After `fork()`, locks held by dead parent threads would deadlock in the -/// child. This writes `RawMutex::INIT` via the `Mutex::raw()` accessor, -/// bypassing the normal unlock path which may interact with parking_lot's -/// internal waiter queues. +/// After `fork()`, any lock held by a now-dead thread would remain +/// permanently locked. We zero the raw bytes (the unlocked state for all +/// `parking_lot` raw lock types) instead of using the normal unlock path, +/// which would interact with stale waiter queues. /// /// # Safety /// /// Must only be called from the single-threaded child process immediately /// after `fork()`, before any other thread is created. -#[cfg(unix)] -pub unsafe fn reinit_mutex_after_fork(mutex: &PyMutex) { - // Use Mutex::raw() to access the underlying lock without layout assumptions. - // parking_lot::RawMutex (AtomicU8) and RawCellMutex (Cell) both - // represent the unlocked state as all-zero bytes. +/// The type `T` must represent the unlocked state as all-zero bytes +/// (true for `parking_lot::RawMutex`, `RawRwLock`, `RawReentrantMutex`, etc.). +pub unsafe fn zero_reinit_after_fork(lock: *const T) { unsafe { - let raw = mutex.raw() as *const RawMutex as *mut u8; - core::ptr::write_bytes(raw, 0, core::mem::size_of::()); + core::ptr::write_bytes(lock as *mut u8, 0, core::mem::size_of::()); } } -/// Reset a `PyRwLock` to its initial (unlocked) state after `fork()`. +/// Reset a `PyMutex` after `fork()`. See [`zero_reinit_after_fork`]. +/// +/// # Safety /// -/// Same rationale as [`reinit_mutex_after_fork`] — dead threads' read or -/// write locks would cause permanent deadlock in the child. +/// Must only be called from the single-threaded child process immediately +/// after `fork()`, before any other thread is created. +#[cfg(unix)] +pub unsafe fn reinit_mutex_after_fork(mutex: &PyMutex) { + unsafe { zero_reinit_after_fork(mutex.raw()) } +} + +/// Reset a `PyRwLock` after `fork()`. See [`zero_reinit_after_fork`]. 
/// /// # Safety /// @@ -101,10 +106,7 @@ pub unsafe fn reinit_mutex_after_fork(mutex: &PyMutex) { /// after `fork()`, before any other thread is created. #[cfg(unix)] pub unsafe fn reinit_rwlock_after_fork(rwlock: &PyRwLock) { - unsafe { - let raw = rwlock.raw() as *const RawRwLock as *mut u8; - core::ptr::write_bytes(raw, 0, core::mem::size_of::()); - } + unsafe { zero_reinit_after_fork(rwlock.raw()) } } /// Reset a `PyThreadMutex` to its initial (unlocked, unowned) state after `fork()`. diff --git a/crates/common/src/lock/thread_mutex.rs b/crates/common/src/lock/thread_mutex.rs index 5b5b89f4eb1..884556c4476 100644 --- a/crates/common/src/lock/thread_mutex.rs +++ b/crates/common/src/lock/thread_mutex.rs @@ -54,6 +54,18 @@ impl RawThreadMutex { .is_some() } + /// Like `lock()` but wraps the blocking wait in `wrap_fn`. + /// The caller can use this to detach thread state while waiting. + pub fn lock_wrapped(&self, wrap_fn: F) -> bool { + let id = self.get_thread_id.nonzero_thread_id().get(); + if self.owner.load(Ordering::Relaxed) == id { + return false; + } + wrap_fn(&|| self.mutex.lock()); + self.owner.store(id, Ordering::Relaxed); + true + } + /// Returns `Some(true)` if able to successfully lock without blocking, `Some(false)` /// otherwise, and `None` when the mutex is already locked on the current thread. pub fn try_lock(&self) -> Option { @@ -135,6 +147,23 @@ impl ThreadMutex { None } } + + /// Like `lock()` but wraps the blocking wait in `wrap_fn`. + /// The caller can use this to detach thread state while waiting. 
+ pub fn lock_wrapped( + &self, + wrap_fn: F, + ) -> Option> { + if self.raw.lock_wrapped(wrap_fn) { + Some(ThreadMutexGuard { + mu: self, + marker: PhantomData, + }) + } else { + None + } + } + pub fn try_lock(&self) -> Result, TryLockThreadError> { match self.raw.try_lock() { Some(true) => Ok(ThreadMutexGuard { diff --git a/crates/compiler-core/Cargo.toml b/crates/compiler-core/Cargo.toml index f4e619b95a4..7be58432cdf 100644 --- a/crates/compiler-core/Cargo.toml +++ b/crates/compiler-core/Cargo.toml @@ -14,6 +14,7 @@ ruff_source_file = { workspace = true } rustpython-wtf8 = { workspace = true } bitflags = { workspace = true } +bitflagset = { workspace = true } itertools = { workspace = true } malachite-bigint = { workspace = true } num-complex = { workspace = true } diff --git a/crates/compiler-core/src/bytecode.rs b/crates/compiler-core/src/bytecode.rs index 46182962654..63801ba519a 100644 --- a/crates/compiler-core/src/bytecode.rs +++ b/crates/compiler-core/src/bytecode.rs @@ -11,7 +11,7 @@ use bitflags::bitflags; use core::{ cell::UnsafeCell, hash, mem, - ops::Deref, + ops::{Deref, Index, IndexMut}, sync::atomic::{AtomicU8, AtomicU16, AtomicUsize, Ordering}, }; use itertools::Itertools; @@ -26,13 +26,13 @@ pub use crate::bytecode::{ oparg::{ BinaryOperator, BuildSliceArgCount, CommonConstant, ComparisonOperator, ConvertValueOparg, IntrinsicFunction1, IntrinsicFunction2, Invert, Label, LoadAttr, LoadSuperAttr, - MakeFunctionFlags, NameIdx, OpArg, OpArgByte, OpArgState, OpArgType, RaiseKind, ResumeType, - SpecialMethod, UnpackExArgs, + MakeFunctionFlag, MakeFunctionFlags, NameIdx, OpArg, OpArgByte, OpArgState, OpArgType, + RaiseKind, ResumeType, SpecialMethod, UnpackExArgs, }, }; mod instruction; -mod oparg; +pub mod oparg; /// Exception table entry for zero-cost exception handling /// Format: (start, size, target, depth<<1|lasti) @@ -293,6 +293,47 @@ impl ConstantBag for BasicBag { } } +#[derive(Clone)] +pub struct Constants(Box<[C]>); + +impl Deref for 
Constants { + type Target = [C]; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl Index for Constants { + type Output = C; + + fn index(&self, consti: oparg::ConstIdx) -> &Self::Output { + &self.0[consti.as_usize()] + } +} + +impl FromIterator for Constants { + fn from_iter>(iter: T) -> Self { + Self(iter.into_iter().collect()) + } +} + +// TODO: Newtype "CodeObject.varnames". Make sure only `oparg:VarNum` can be used as index +impl Index for [T] { + type Output = T; + + fn index(&self, var_num: oparg::VarNum) -> &Self::Output { + &self[var_num.as_usize()] + } +} + +// TODO: Newtype "CodeObject.varnames". Make sure only `oparg:VarNum` can be used as index +impl IndexMut for [T] { + fn index_mut(&mut self, var_num: oparg::VarNum) -> &mut Self::Output { + &mut self[var_num.as_usize()] + } +} + /// Primary container of a single code object. Each python function has /// a code object. Also a module has a code object. #[derive(Clone)] @@ -312,7 +353,7 @@ pub struct CodeObject { /// Qualified name of the object (like CPython's co_qualname) pub qualname: C::Name, pub cell2arg: Option>, - pub constants: Box<[C]>, + pub constants: Constants, pub names: Box<[C::Name]>, pub varnames: Box<[C::Name]>, pub cellvars: Box<[C::Name]>, @@ -1012,16 +1053,14 @@ impl CodeObject { pub fn map_bag(self, bag: Bag) -> CodeObject { let map_names = |names: Box<[C::Name]>| { names - .into_vec() - .into_iter() + .iter() .map(|x| bag.make_name(x.as_ref())) .collect::>() }; CodeObject { constants: self .constants - .into_vec() - .into_iter() + .iter() .map(|x| bag.make_constant(x.borrow_constant())) .collect(), names: map_names(self.names), @@ -1095,11 +1134,11 @@ impl fmt::Display for CodeObject { pub trait InstrDisplayContext { type Constant: Constant; - fn get_constant(&self, i: usize) -> &Self::Constant; + fn get_constant(&self, consti: oparg::ConstIdx) -> &Self::Constant; fn get_name(&self, i: usize) -> &str; - fn get_varname(&self, i: usize) -> &str; + fn get_varname(&self, 
var_num: oparg::VarNum) -> &str; fn get_cell_name(&self, i: usize) -> &str; } @@ -1107,16 +1146,16 @@ pub trait InstrDisplayContext { impl InstrDisplayContext for CodeObject { type Constant = C; - fn get_constant(&self, i: usize) -> &C { - &self.constants[i] + fn get_constant(&self, consti: oparg::ConstIdx) -> &C { + &self.constants[consti] } fn get_name(&self, i: usize) -> &str { self.names[i].as_ref() } - fn get_varname(&self, i: usize) -> &str { - self.varnames[i].as_ref() + fn get_varname(&self, var_num: oparg::VarNum) -> &str { + self.varnames[var_num].as_ref() } fn get_cell_name(&self, i: usize) -> &str { diff --git a/crates/compiler-core/src/bytecode/instruction.rs b/crates/compiler-core/src/bytecode/instruction.rs index 754447956fa..90c1f68f9ed 100644 --- a/crates/compiler-core/src/bytecode/instruction.rs +++ b/crates/compiler-core/src/bytecode/instruction.rs @@ -4,9 +4,9 @@ use crate::{ bytecode::{ BorrowedConstant, Constant, InstrDisplayContext, oparg::{ - BinaryOperator, BuildSliceArgCount, CommonConstant, ComparisonOperator, + self, BinaryOperator, BuildSliceArgCount, CommonConstant, ComparisonOperator, ConvertValueOparg, IntrinsicFunction1, IntrinsicFunction2, Invert, Label, LoadAttr, - LoadSuperAttr, MakeFunctionFlags, NameIdx, OpArg, OpArgByte, OpArgType, RaiseKind, + LoadSuperAttr, MakeFunctionFlag, NameIdx, OpArg, OpArgByte, OpArgType, RaiseKind, SpecialMethod, StoreFastLoadFast, UnpackExArgs, }, }, @@ -133,7 +133,7 @@ pub enum Instruction { i: Arg, } = 62, DeleteFast { - var_num: Arg, + var_num: Arg, } = 63, DeleteGlobal { namei: Arg, @@ -186,25 +186,25 @@ pub enum Instruction { idx: Arg, } = 81, LoadConst { - consti: Arg, + consti: Arg, } = 82, LoadDeref { i: Arg, } = 83, LoadFast { - var_num: Arg, + var_num: Arg, } = 84, LoadFastAndClear { - var_num: Arg, + var_num: Arg, } = 85, LoadFastBorrow { - var_num: Arg, + var_num: Arg, } = 86, LoadFastBorrowLoadFastBorrow { var_nums: Arg, } = 87, LoadFastCheck { - var_num: Arg, + var_num: Arg, } = 88, 
LoadFastLoadFast { var_nums: Arg, @@ -264,7 +264,7 @@ pub enum Instruction { i: Arg, } = 107, SetFunctionAttribute { - flag: Arg, + flag: Arg, } = 108, SetUpdate { i: Arg, @@ -276,7 +276,7 @@ pub enum Instruction { i: Arg, } = 111, StoreFast { - var_num: Arg, + var_num: Arg, } = 112, StoreFastLoadFast { var_nums: Arg, @@ -1105,7 +1105,13 @@ impl InstructionMetadata for Instruction { }; ($variant:ident, $map:ident = $arg_marker:expr) => {{ let arg = $arg_marker.get(arg); - write!(f, "{:pad$}({}, {})", stringify!($variant), arg, $map(arg)) + write!( + f, + "{:pad$}({}, {})", + stringify!($variant), + u32::from(arg), + $map(arg) + ) }}; ($variant:ident, $arg_marker:expr) => { write!(f, "{:pad$}({})", stringify!($variant), $arg_marker.get(arg)) @@ -1120,26 +1126,29 @@ impl InstructionMetadata for Instruction { }; } - let varname = |i: u32| ctx.get_varname(i as usize); + let varname = |var_num: oparg::VarNum| ctx.get_varname(var_num); let name = |i: u32| ctx.get_name(i as usize); let cell_name = |i: u32| ctx.get_cell_name(i as usize); - let fmt_const = - |op: &str, arg: OpArg, f: &mut fmt::Formatter<'_>, idx: &Arg| -> fmt::Result { - let value = ctx.get_constant(idx.get(arg) as usize); - match value.borrow_constant() { - BorrowedConstant::Code { code } if expand_code_objects => { - write!(f, "{op:pad$}({code:?}):")?; - code.display_inner(f, true, level + 1)?; - Ok(()) - } - c => { - write!(f, "{op:pad$}(")?; - c.fmt_display(f)?; - write!(f, ")") - } + let fmt_const = |op: &str, + arg: OpArg, + f: &mut fmt::Formatter<'_>, + consti: &Arg| + -> fmt::Result { + let value = ctx.get_constant(consti.get(arg)); + match value.borrow_constant() { + BorrowedConstant::Code { code } if expand_code_objects => { + write!(f, "{op:pad$}({code:?}):")?; + code.display_inner(f, true, level + 1)?; + Ok(()) } - }; + c => { + write!(f, "{op:pad$}(")?; + c.fmt_display(f)?; + write!(f, ")") + } + } + }; match self { Self::BinarySlice => w!(BINARY_SLICE), @@ -1223,16 +1232,16 @@ impl 
InstructionMetadata for Instruction { let oparg = var_nums.get(arg); let idx1 = oparg >> 4; let idx2 = oparg & 15; - let name1 = varname(idx1); - let name2 = varname(idx2); + let name1 = varname(idx1.into()); + let name2 = varname(idx2.into()); write!(f, "{:pad$}({}, {})", "LOAD_FAST_LOAD_FAST", name1, name2) } Self::LoadFastBorrowLoadFastBorrow { var_nums } => { let oparg = var_nums.get(arg); let idx1 = oparg >> 4; let idx2 = oparg & 15; - let name1 = varname(idx1); - let name2 = varname(idx2); + let name1 = varname(idx1.into()); + let name2 = varname(idx2.into()); write!( f, "{:pad$}({}, {})", @@ -1359,8 +1368,8 @@ impl InstructionMetadata for Instruction { f, "{:pad$}({}, {})", "STORE_FAST_STORE_FAST", - varname(idx1), - varname(idx2) + varname(idx1.into()), + varname(idx2.into()) ) } Self::StoreGlobal { namei } => w!(STORE_GLOBAL, name = namei), diff --git a/crates/compiler-core/src/bytecode/oparg.rs b/crates/compiler-core/src/bytecode/oparg.rs index 729b84db591..49bd715c459 100644 --- a/crates/compiler-core/src/bytecode/oparg.rs +++ b/crates/compiler-core/src/bytecode/oparg.rs @@ -1,5 +1,3 @@ -use bitflags::bitflags; - use core::fmt; use crate::{ @@ -423,34 +421,40 @@ oparg_enum!( } ); -bitflags! { - #[derive(Copy, Clone, Debug, PartialEq)] - pub struct MakeFunctionFlags: u8 { - const CLOSURE = 0x01; - const ANNOTATIONS = 0x02; - const KW_ONLY_DEFAULTS = 0x04; - const DEFAULTS = 0x08; - const TYPE_PARAMS = 0x10; +bitflagset::bitflag! { + #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] + #[repr(u8)] + pub enum MakeFunctionFlag { + Closure = 0, + Annotations = 1, + KwOnlyDefaults = 2, + Defaults = 3, + TypeParams = 4, /// PEP 649: __annotate__ function closure (instead of __annotations__ dict) - const ANNOTATE = 0x20; + Annotate = 5, } } -impl TryFrom for MakeFunctionFlags { +bitflagset::bitflagset! 
{ + #[derive(Copy, Clone, PartialEq, Eq)] + pub struct MakeFunctionFlags(u8): MakeFunctionFlag +} + +impl TryFrom for MakeFunctionFlag { type Error = MarshalError; fn try_from(value: u32) -> Result { - Self::from_bits(value as u8).ok_or(Self::Error::InvalidBytecode) + Self::try_from(value as u8).map_err(|_| MarshalError::InvalidBytecode) } } -impl From for u32 { - fn from(value: MakeFunctionFlags) -> Self { - value.bits().into() +impl From for u32 { + fn from(flag: MakeFunctionFlag) -> Self { + flag as u32 } } -impl OpArgType for MakeFunctionFlags {} +impl OpArgType for MakeFunctionFlag {} oparg_enum!( /// The possible comparison operators. @@ -872,3 +876,56 @@ impl LoadAttrBuilder { self } } + +macro_rules! newtype_oparg { + ( + $(#[$oparg_meta:meta])* + $vis:vis struct $name:ident(u32) + ) => { + $(#[$oparg_meta])* + $vis struct $name(u32); + + impl $name { + #[must_use] + pub const fn from_u32(value: u32) -> Self { + Self(value) + } + + /// Returns the oparg as a `u32` value. + #[must_use] + pub const fn as_u32(self) -> u32 { + self.0 + } + + /// Returns the oparg as a `usize` value. 
+ #[must_use] + pub const fn as_usize(self) -> usize { + self.0 as usize + } + } + + impl From for $name { + fn from(value: u32) -> Self { + Self::from_u32(value) + } + } + + impl From<$name> for u32 { + fn from(value: $name) -> Self { + value.as_u32() + } + } + + impl OpArgType for $name {} + } +} + +newtype_oparg!( + #[derive(Clone, Copy)] + pub struct ConstIdx(u32) +); + +newtype_oparg!( + #[derive(Clone, Copy)] + pub struct VarNum(u32) +); diff --git a/crates/compiler-core/src/marshal.rs b/crates/compiler-core/src/marshal.rs index 310bad9d868..ba3cf7a35c3 100644 --- a/crates/compiler-core/src/marshal.rs +++ b/crates/compiler-core/src/marshal.rs @@ -240,7 +240,7 @@ pub fn deserialize_code( let len = rdr.read_u32()?; let constants = (0..len) .map(|_| deserialize_value(rdr, bag)) - .collect::>>()?; + .collect::>()?; let mut read_names = || { let len = rdr.read_u32()?; diff --git a/crates/derive-impl/src/pyclass.rs b/crates/derive-impl/src/pyclass.rs index dfb02a3eda8..1fec51ddd42 100644 --- a/crates/derive-impl/src/pyclass.rs +++ b/crates/derive-impl/src/pyclass.rs @@ -1,8 +1,8 @@ use super::Diagnostic; use crate::util::{ ALL_ALLOWED_NAMES, ClassItemMeta, ContentItem, ContentItemInner, ErrorVec, ExceptionItemMeta, - ItemMeta, ItemMetaInner, ItemNursery, SimpleItemMeta, format_doc, pyclass_ident_and_attrs, - pyexception_ident_and_attrs, text_signature, + ItemMeta, ItemMetaInner, ItemNursery, SimpleItemMeta, format_doc, infer_native_call_flags, + pyclass_ident_and_attrs, pyexception_ident_and_attrs, text_signature, }; use core::str::FromStr; use proc_macro2::{Delimiter, Group, Span, TokenStream, TokenTree}; @@ -1015,6 +1015,16 @@ where let raw = item_meta.raw()?; let sig_doc = text_signature(func.sig(), &py_name); + let has_receiver = func + .sig() + .inputs + .iter() + .any(|arg| matches!(arg, syn::FnArg::Receiver(_))); + let drop_first_typed = match self.inner.attr_name { + AttrName::Method | AttrName::ClassMethod if !has_receiver && !raw => 1, + _ => 0, + }; + 
let call_flags = infer_native_call_flags(func.sig(), drop_first_typed); // Add #[allow(non_snake_case)] for setter methods like set___name__ let method_name = ident.to_string(); @@ -1031,6 +1041,7 @@ where doc, raw, attr_name: self.inner.attr_name, + call_flags, }); Ok(()) } @@ -1248,6 +1259,7 @@ struct MethodNurseryItem { raw: bool, doc: Option, attr_name: AttrName, + call_flags: TokenStream, } impl MethodNursery { @@ -1278,7 +1290,7 @@ impl ToTokens for MethodNursery { } else { quote! { None } }; - let flags = match &item.attr_name { + let binding_flags = match &item.attr_name { AttrName::Method => { quote! { rustpython_vm::function::PyMethodFlags::METHOD } } @@ -1290,6 +1302,12 @@ impl ToTokens for MethodNursery { } _ => unreachable!(), }; + let call_flags = &item.call_flags; + let flags = quote! { + rustpython_vm::function::PyMethodFlags::from_bits_retain( + (#binding_flags).bits() | (#call_flags).bits() + ) + }; // TODO: intern // let py_name = if py_name.starts_with("__") && py_name.ends_with("__") { // let name_ident = Ident::new(&py_name, ident.span()); diff --git a/crates/derive-impl/src/pymodule.rs b/crates/derive-impl/src/pymodule.rs index 775e6858520..b4b5535200c 100644 --- a/crates/derive-impl/src/pymodule.rs +++ b/crates/derive-impl/src/pymodule.rs @@ -2,8 +2,8 @@ use crate::error::Diagnostic; use crate::pystructseq::PyStructSequenceMeta; use crate::util::{ ALL_ALLOWED_NAMES, AttrItemMeta, AttributeExt, ClassItemMeta, ContentItem, ContentItemInner, - ErrorVec, ItemMeta, ItemNursery, ModuleItemMeta, SimpleItemMeta, format_doc, iter_use_idents, - pyclass_ident_and_attrs, text_signature, + ErrorVec, ItemMeta, ItemNursery, ModuleItemMeta, SimpleItemMeta, format_doc, + infer_native_call_flags, iter_use_idents, pyclass_ident_and_attrs, text_signature, }; use core::str::FromStr; use proc_macro2::{Delimiter, Group, TokenStream, TokenTree}; @@ -525,6 +525,7 @@ struct FunctionNurseryItem { cfgs: Vec, ident: Ident, doc: String, + call_flags: TokenStream, } impl 
FunctionNursery { @@ -550,7 +551,6 @@ struct ValidatedFunctionNursery(FunctionNursery); impl ToTokens for ValidatedFunctionNursery { fn to_tokens(&self, tokens: &mut TokenStream) { let mut inner_tokens = TokenStream::new(); - let flags = quote! { rustpython_vm::function::PyMethodFlags::empty() }; for item in &self.0.items { let ident = &item.ident; let cfgs = &item.cfgs; @@ -558,6 +558,7 @@ impl ToTokens for ValidatedFunctionNursery { let py_names = &item.py_names; let doc = &item.doc; let doc = quote!(Some(#doc)); + let flags = &item.call_flags; inner_tokens.extend(quote![ #( @@ -706,12 +707,14 @@ impl ModuleItem for FunctionItem { py_names } }; + let call_flags = infer_native_call_flags(func.sig(), 0); args.context.function_items.add_item(FunctionNurseryItem { ident: ident.to_owned(), py_names, cfgs: args.cfgs.to_vec(), doc, + call_flags, }); Ok(()) } diff --git a/crates/derive-impl/src/util.rs b/crates/derive-impl/src/util.rs index a4bf7e6a8fe..068bde9bccd 100644 --- a/crates/derive-impl/src/util.rs +++ b/crates/derive-impl/src/util.rs @@ -732,6 +732,77 @@ pub(crate) fn text_signature(sig: &Signature, name: &str) -> String { } } +pub(crate) fn infer_native_call_flags(sig: &Signature, drop_first_typed: usize) -> TokenStream { + // Best-effort mapping of Rust function signatures to CPython-style + // METH_* calling convention flags used by CALL specialization. + let mut typed_args = Vec::new(); + for arg in &sig.inputs { + let syn::FnArg::Typed(typed) = arg else { + continue; + }; + let ty_tokens = &typed.ty; + let ty = quote!(#ty_tokens).to_string().replace(' ', ""); + // `vm: &VirtualMachine` is not a Python-level argument. 
+ if ty.starts_with('&') && ty.ends_with("VirtualMachine") { + continue; + } + typed_args.push(ty); + } + + let mut user_args = typed_args.into_iter(); + for _ in 0..drop_first_typed { + if user_args.next().is_none() { + break; + } + } + + let mut has_keywords = false; + let mut variable_arity = false; + let mut fixed_positional = 0usize; + + for ty in user_args { + let is_named = |name: &str| { + ty == name + || ty.starts_with(&format!("{name}<")) + || ty.contains(&format!("::{name}<")) + || ty.ends_with(&format!("::{name}")) + }; + + if is_named("FuncArgs") { + has_keywords = true; + variable_arity = true; + continue; + } + if is_named("KwArgs") { + has_keywords = true; + variable_arity = true; + continue; + } + if is_named("PosArgs") || is_named("OptionalArg") || is_named("OptionalOption") { + variable_arity = true; + continue; + } + fixed_positional += 1; + } + + if has_keywords { + quote! { + rustpython_vm::function::PyMethodFlags::from_bits_retain( + rustpython_vm::function::PyMethodFlags::FASTCALL.bits() + | rustpython_vm::function::PyMethodFlags::KEYWORDS.bits() + ) + } + } else if variable_arity { + quote! { rustpython_vm::function::PyMethodFlags::FASTCALL } + } else { + match fixed_positional { + 0 => quote! { rustpython_vm::function::PyMethodFlags::NOARGS }, + 1 => quote! { rustpython_vm::function::PyMethodFlags::O }, + _ => quote! 
{ rustpython_vm::function::PyMethodFlags::FASTCALL }, + } + } +} + fn func_sig(sig: &Signature) -> String { sig.inputs .iter() diff --git a/crates/jit/src/instructions.rs b/crates/jit/src/instructions.rs index 9d8be5bc6e3..a215710da3b 100644 --- a/crates/jit/src/instructions.rs +++ b/crates/jit/src/instructions.rs @@ -6,7 +6,7 @@ use cranelift::prelude::*; use num_traits::cast::ToPrimitive; use rustpython_compiler_core::bytecode::{ self, BinaryOperator, BorrowedConstant, CodeObject, ComparisonOperator, Instruction, - IntrinsicFunction1, Label, OpArg, OpArgState, + IntrinsicFunction1, Label, OpArg, OpArgState, oparg, }; use std::collections::HashMap; @@ -94,7 +94,10 @@ impl<'a, 'b> FunctionCompiler<'a, 'b> { let params = compiler.builder.func.dfg.block_params(entry_block).to_vec(); for (i, (ty, val)) in arg_types.iter().zip(params).enumerate() { compiler - .store_variable(i as u32, JitValue::from_type_and_value(ty.clone(), val)) + .store_variable( + (i as u32).into(), + JitValue::from_type_and_value(ty.clone(), val), + ) .unwrap(); } compiler @@ -105,14 +108,10 @@ impl<'a, 'b> FunctionCompiler<'a, 'b> { self.stack.drain(stack_len - count..).collect() } - fn store_variable( - &mut self, - idx: bytecode::NameIdx, - val: JitValue, - ) -> Result<(), JitCompileError> { + fn store_variable(&mut self, idx: oparg::VarNum, val: JitValue) -> Result<(), JitCompileError> { let builder = &mut self.builder; let ty = val.to_jit_type().ok_or(JitCompileError::NotSupported)?; - let local = self.variables[idx as usize].get_or_insert_with(|| { + let local = self.variables[idx].get_or_insert_with(|| { let var = builder.declare_var(ty.to_cranelift()); Local { var, @@ -637,9 +636,8 @@ impl<'a, 'b> FunctionCompiler<'a, 'b> { Ok(()) } Instruction::LoadConst { consti } => { - let val = self.prepare_const( - bytecode.constants[consti.get(arg) as usize].borrow_constant(), - )?; + let val = + self.prepare_const(bytecode.constants[consti.get(arg)].borrow_constant())?; self.stack.push(val); 
Ok(()) } @@ -650,7 +648,7 @@ impl<'a, 'b> FunctionCompiler<'a, 'b> { Ok(()) } Instruction::LoadFast { var_num } | Instruction::LoadFastBorrow { var_num } => { - let local = self.variables[var_num.get(arg) as usize] + let local = self.variables[var_num.get(arg)] .as_ref() .ok_or(JitCompileError::BadBytecode)?; self.stack.push(JitValue::from_type_and_value( diff --git a/crates/jit/tests/common.rs b/crates/jit/tests/common.rs index 095bb87d7be..629cdccc7fd 100644 --- a/crates/jit/tests/common.rs +++ b/crates/jit/tests/common.rs @@ -1,6 +1,6 @@ use core::ops::ControlFlow; use rustpython_compiler_core::bytecode::{ - CodeObject, ConstantData, Instruction, OpArg, OpArgState, + CodeObject, ConstantData, Constants, Instruction, OpArg, OpArgState, }; use rustpython_jit::{CompiledCode, JitType}; use rustpython_wtf8::{Wtf8, Wtf8Buf}; @@ -77,7 +77,7 @@ fn extract_annotations_from_annotate_code(code: &CodeObject) -> HashMap { - stack.push((true, consti.get(arg) as usize)); + stack.push((true, consti.get(arg).as_usize())); } Instruction::LoadName { namei } => { stack.push((false, namei.get(arg) as usize)); @@ -99,7 +99,8 @@ fn extract_annotations_from_annotate_code(code: &CodeObject) -> HashMap, names: &[String], ) -> ControlFlow<()> { match instruction { @@ -193,8 +194,7 @@ impl StackMachine { // No-op for JIT tests } Instruction::LoadConst { consti } => { - let idx = consti.get(arg); - self.stack.push(constants[idx as usize].clone().into()) + self.stack.push(constants[consti.get(arg)].clone().into()) } Instruction::LoadName { namei } => self .stack @@ -243,42 +243,40 @@ impl StackMachine { }; let attr_value = self.stack.pop().expect("Expected attribute value on stack"); - let flags = flag.get(arg); + let flag_value = flag.get(arg); - // Handle ANNOTATE flag (PEP 649 style - Python 3.14+) - // The attr_value is a function that returns annotations when called - if flags.contains(rustpython_compiler_core::bytecode::MakeFunctionFlags::ANNOTATE) { - if let 
StackValue::Function(annotate_func) = attr_value { - // Parse the annotate function's bytecode to extract annotations - // The pattern is: LOAD_CONST (key), LOAD_NAME (value), ... BUILD_MAP - let annotate_code = &annotate_func.code; - let annotations = extract_annotations_from_annotate_code(annotate_code); + match flag_value { + rustpython_compiler_core::bytecode::MakeFunctionFlag::Annotate => { + // Handle ANNOTATE flag (PEP 649 style - Python 3.14+) + if let StackValue::Function(annotate_func) = attr_value { + let annotate_code = &annotate_func.code; + let annotations = extract_annotations_from_annotate_code(annotate_code); - let updated_func = Function { - code: func.code, - annotations, - }; - self.stack.push(StackValue::Function(updated_func)); - } else { - panic!("Expected annotate function for ANNOTATE flag"); + let updated_func = Function { + code: func.code, + annotations, + }; + self.stack.push(StackValue::Function(updated_func)); + } else { + panic!("Expected annotate function for ANNOTATE flag"); + } } - } - // Handle old ANNOTATIONS flag (Python 3.12 style) - else if flags - .contains(rustpython_compiler_core::bytecode::MakeFunctionFlags::ANNOTATIONS) - { - if let StackValue::Map(annotations) = attr_value { - let updated_func = Function { - code: func.code, - annotations, - }; - self.stack.push(StackValue::Function(updated_func)); - } else { - panic!("Expected annotations to be a map"); + rustpython_compiler_core::bytecode::MakeFunctionFlag::Annotations => { + // Handle old ANNOTATIONS flag (Python 3.12 style) + if let StackValue::Map(annotations) = attr_value { + let updated_func = Function { + code: func.code, + annotations, + }; + self.stack.push(StackValue::Function(updated_func)); + } else { + panic!("Expected annotations to be a map"); + } + } + _ => { + // For other attributes, just push the function back unchanged + self.stack.push(StackValue::Function(func)); } - } else { - // For other attributes, just push the function back unchanged - 
self.stack.push(StackValue::Function(func)); } } Instruction::ReturnValue => return ControlFlow::Break(()), diff --git a/crates/stdlib/Cargo.toml b/crates/stdlib/Cargo.toml index deb1c332f93..36967095639 100644 --- a/crates/stdlib/Cargo.toml +++ b/crates/stdlib/Cargo.toml @@ -31,6 +31,11 @@ rustpython-derive = { workspace = true } rustpython-vm = { workspace = true, default-features = false, features = ["compiler"]} rustpython-common = { workspace = true } +ruff_python_parser = { workspace = true } +ruff_python_ast = { workspace = true } +ruff_text_size = { workspace = true } +ruff_source_file = { workspace = true } + ahash = { workspace = true } ascii = { workspace = true } cfg-if = { workspace = true } @@ -102,7 +107,7 @@ chrono.workspace = true # uuid [target.'cfg(not(any(target_os = "ios", target_os = "android", target_os = "windows", target_arch = "wasm32", target_os = "redox")))'.dependencies] mac_address = "1.1.3" -uuid = { version = "1.21.0", features = ["v1"] } +uuid = { version = "1.22.0", features = ["v1"] } [target.'cfg(all(unix, not(target_os = "redox"), not(target_os = "ios")))'.dependencies] termios = "0.3.3" @@ -115,7 +120,7 @@ rustix = { workspace = true } memmap2 = "0.9.9" page_size = "0.6" gethostname = "1.0.2" -socket2 = { version = "0.6.0", features = ["all"] } +socket2 = { version = "0.6.3", features = ["all"] } dns-lookup = "3.0" # OpenSSL dependencies (optional, for ssl-openssl feature) @@ -125,7 +130,7 @@ openssl-probe = { version = "0.2.1", optional = true } foreign-types-shared = { version = "0.1.1", optional = true } # Rustls dependencies (optional, for ssl-rustls feature) -rustls = { version = "0.23.36", default-features = false, features = ["std", "tls12", "aws_lc_rs"], optional = true } +rustls = { version = "0.23.37", default-features = false, features = ["std", "tls12", "aws_lc_rs"], optional = true } rustls-native-certs = { version = "0.8", optional = true } rustls-pemfile = { version = "2.2", optional = true } 
rustls-platform-verifier = { version = "0.6", optional = true } diff --git a/crates/stdlib/src/_asyncio.rs b/crates/stdlib/src/_asyncio.rs index 2733e801251..8b32e4625da 100644 --- a/crates/stdlib/src/_asyncio.rs +++ b/crates/stdlib/src/_asyncio.rs @@ -151,7 +151,7 @@ pub(crate) mod _asyncio { fn init(zelf: PyRef, args: Self::Args, vm: &VirtualMachine) -> PyResult<()> { // Future does not accept positional arguments if !args.args.is_empty() { - return Err(vm.new_type_error("Future() takes no positional arguments".to_string())); + return Err(vm.new_type_error("Future() takes no positional arguments")); } // Extract only 'loop' keyword argument let loop_ = args.kwargs.get("loop").cloned(); @@ -160,7 +160,7 @@ pub(crate) mod _asyncio { } #[pyclass( - flags(BASETYPE, HAS_DICT), + flags(BASETYPE, HAS_DICT, HAS_WEAKREF), with(Constructor, Initializer, Destructor, Representable, Iterable) )] impl PyFuture { @@ -265,7 +265,7 @@ pub(crate) mod _asyncio { #[pymethod] fn set_result(zelf: PyRef, result: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { if zelf.fut_loop.read().is_none() { - return Err(vm.new_runtime_error("Future object is not initialized.".to_string())); + return Err(vm.new_runtime_error("Future object is not initialized.")); } if zelf.fut_state.load() != FutureState::Pending { return Err(new_invalid_state_error(vm, "invalid state")); @@ -283,7 +283,7 @@ pub(crate) mod _asyncio { vm: &VirtualMachine, ) -> PyResult<()> { if zelf.fut_loop.read().is_none() { - return Err(vm.new_runtime_error("Future object is not initialized.".to_string())); + return Err(vm.new_runtime_error("Future object is not initialized.")); } if zelf.fut_state.load() != FutureState::Pending { return Err(new_invalid_state_error(vm, "invalid state")); @@ -336,7 +336,7 @@ pub(crate) mod _asyncio { vm: &VirtualMachine, ) -> PyResult<()> { if zelf.fut_loop.read().is_none() { - return Err(vm.new_runtime_error("Future object is not initialized.".to_string())); + return 
Err(vm.new_runtime_error("Future object is not initialized.")); } let ctx = match args.context.flatten() { Some(c) => c, @@ -364,7 +364,7 @@ pub(crate) mod _asyncio { #[pymethod] fn remove_done_callback(&self, func: PyObjectRef, vm: &VirtualMachine) -> PyResult { if self.fut_loop.read().is_none() { - return Err(vm.new_runtime_error("Future object is not initialized.".to_string())); + return Err(vm.new_runtime_error("Future object is not initialized.")); } let mut cleared_callback0 = 0usize; @@ -461,7 +461,7 @@ pub(crate) mod _asyncio { #[pymethod] fn cancel(zelf: PyRef, args: CancelArgs, vm: &VirtualMachine) -> PyResult { if zelf.fut_loop.read().is_none() { - return Err(vm.new_runtime_error("Future object is not initialized.".to_string())); + return Err(vm.new_runtime_error("Future object is not initialized.")); } if zelf.fut_state.load() != FutureState::Pending { // Clear log_tb even when cancel fails @@ -499,7 +499,8 @@ pub(crate) mod _asyncio { } fn make_cancelled_error_impl(&self, vm: &VirtualMachine) -> PyBaseExceptionRef { - if let Some(exc) = self.fut_cancelled_exc.read().clone() + // If a saved CancelledError exists, take it (clearing the stored reference) + if let Some(exc) = self.fut_cancelled_exc.write().take() && let Ok(exc) = exc.downcast::() { return exc; @@ -508,12 +509,10 @@ pub(crate) mod _asyncio { let msg = self.fut_cancel_msg.read().clone(); let args = if let Some(m) = msg { vec![m] } else { vec![] }; - let exc = match get_cancelled_error_type(vm) { + match get_cancelled_error_type(vm) { Ok(cancelled_error) => vm.new_exception(cancelled_error, args), Err(_) => vm.new_runtime_error("cancelled"), - }; - *self.fut_cancelled_exc.write() = Some(exc.clone().into()); - exc + } } fn schedule_callbacks(zelf: &PyRef, vm: &VirtualMachine) -> PyResult<()> { @@ -598,9 +597,7 @@ pub(crate) mod _asyncio { self.fut_blocking.store(v, Ordering::Relaxed); Ok(()) } - PySetterValue::Delete => { - Err(vm.new_attribute_error("cannot delete attribute".to_string())) - } 
+ PySetterValue::Delete => Err(vm.new_attribute_error("cannot delete attribute")), } } @@ -670,16 +667,12 @@ pub(crate) mod _asyncio { match value { PySetterValue::Assign(v) => { if v { - return Err(vm.new_value_error( - "_log_traceback can only be set to False".to_string(), - )); + return Err(vm.new_value_error("_log_traceback can only be set to False")); } self.fut_log_tb.store(false, Ordering::Relaxed); Ok(()) } - PySetterValue::Delete => { - Err(vm.new_attribute_error("cannot delete attribute".to_string())) - } + PySetterValue::Delete => Err(vm.new_attribute_error("cannot delete attribute")), } } @@ -1055,7 +1048,7 @@ pub(crate) mod _asyncio { // Must be a subclass of BaseException if !exc_class.fast_issubclass(vm.ctx.exceptions.base_exception_type) { return Err(vm.new_type_error( - "exceptions must be classes or instances deriving from BaseException, not type".to_string() + "exceptions must be classes or instances deriving from BaseException, not type" )); } @@ -1072,9 +1065,9 @@ pub(crate) mod _asyncio { if let OptionalArg::Present(ref val) = exc_val && !vm.is_none(val) { - return Err(vm.new_type_error( - "instance exception may not have a separate value".to_string(), - )); + return Err( + vm.new_type_error("instance exception may not have a separate value") + ); } exc_type } else { @@ -1169,7 +1162,7 @@ pub(crate) mod _asyncio { } #[pyclass( - flags(BASETYPE, HAS_DICT), + flags(BASETYPE, HAS_DICT, HAS_WEAKREF), with(Constructor, Initializer, Destructor, Representable, Iterable) )] impl PyTask { @@ -1315,7 +1308,8 @@ pub(crate) mod _asyncio { } fn make_cancelled_error_impl(&self, vm: &VirtualMachine) -> PyBaseExceptionRef { - if let Some(exc) = self.base.fut_cancelled_exc.read().clone() + // If a saved CancelledError exists, take it (clearing the stored reference) + if let Some(exc) = self.base.fut_cancelled_exc.write().take() && let Ok(exc) = exc.downcast::() { return exc; @@ -1324,12 +1318,10 @@ pub(crate) mod _asyncio { let msg = 
self.base.fut_cancel_msg.read().clone(); let args = if let Some(m) = msg { vec![m] } else { vec![] }; - let exc = match get_cancelled_error_type(vm) { + match get_cancelled_error_type(vm) { Ok(cancelled_error) => vm.new_exception(cancelled_error, args), Err(_) => vm.new_runtime_error("cancelled"), - }; - *self.base.fut_cancelled_exc.write() = Some(exc.clone().into()); - exc + } } #[pymethod] @@ -1339,7 +1331,7 @@ pub(crate) mod _asyncio { vm: &VirtualMachine, ) -> PyResult<()> { if zelf.base.fut_loop.read().is_none() { - return Err(vm.new_runtime_error("Future object is not initialized.".to_string())); + return Err(vm.new_runtime_error("Future object is not initialized.")); } let ctx = match args.context.flatten() { Some(c) => c, @@ -1367,7 +1359,7 @@ pub(crate) mod _asyncio { #[pymethod] fn remove_done_callback(&self, func: PyObjectRef, vm: &VirtualMachine) -> PyResult { if self.base.fut_loop.read().is_none() { - return Err(vm.new_runtime_error("Future object is not initialized.".to_string())); + return Err(vm.new_runtime_error("Future object is not initialized.")); } let mut cleared_callback0 = 0usize; @@ -1686,9 +1678,7 @@ pub(crate) mod _asyncio { self.base.fut_blocking.store(v, Ordering::Relaxed); Ok(()) } - PySetterValue::Delete => { - Err(vm.new_attribute_error("cannot delete attribute".to_string())) - } + PySetterValue::Delete => Err(vm.new_attribute_error("cannot delete attribute")), } } @@ -1718,7 +1708,7 @@ pub(crate) mod _asyncio { Ok(()) } PySetterValue::Delete => { - Err(vm.new_attribute_error("can't delete _log_destroy_pending".to_owned())) + Err(vm.new_attribute_error("can't delete _log_destroy_pending")) } } } @@ -1737,16 +1727,12 @@ pub(crate) mod _asyncio { match value { PySetterValue::Assign(v) => { if v { - return Err(vm.new_value_error( - "_log_traceback can only be set to False".to_string(), - )); + return Err(vm.new_value_error("_log_traceback can only be set to False")); } self.base.fut_log_tb.store(false, Ordering::Relaxed); Ok(()) } - 
PySetterValue::Delete => { - Err(vm.new_attribute_error("cannot delete attribute".to_string())) - } + PySetterValue::Delete => Err(vm.new_attribute_error("cannot delete attribute")), } } @@ -2532,14 +2518,10 @@ pub(crate) mod _asyncio { let running_task = vm.asyncio_running_task.borrow(); match running_task.as_ref() { None => { - return Err(vm.new_runtime_error( - "_leave_task: task is not the current task".to_owned(), - )); + return Err(vm.new_runtime_error("_leave_task: task is not the current task")); } Some(current) if !current.is(&task) => { - return Err(vm.new_runtime_error( - "_leave_task: task is not the current task".to_owned(), - )); + return Err(vm.new_runtime_error("_leave_task: task is not the current task")); } _ => {} } @@ -2777,7 +2759,7 @@ pub(crate) mod _asyncio { .ok_or_else(|| vm.new_attribute_error("CancelledError not found"))?; exc_type .downcast() - .map_err(|_| vm.new_type_error("CancelledError is not a type".to_string())) + .map_err(|_| vm.new_type_error("CancelledError is not a type")) } fn is_cancelled_error(exc: &PyBaseExceptionRef, vm: &VirtualMachine) -> bool { diff --git a/crates/stdlib/src/_remote_debugging.rs b/crates/stdlib/src/_remote_debugging.rs index 57aa9876a01..618ea9fe0a8 100644 --- a/crates/stdlib/src/_remote_debugging.rs +++ b/crates/stdlib/src/_remote_debugging.rs @@ -98,7 +98,7 @@ mod _remote_debugging { type Args = FuncArgs; fn py_new(_cls: &Py, _args: Self::Args, vm: &VirtualMachine) -> PyResult { - Err(vm.new_not_implemented_error("_remote_debugging is not available".to_owned())) + Err(vm.new_not_implemented_error("_remote_debugging is not available")) } } diff --git a/crates/stdlib/src/_sqlite3.rs b/crates/stdlib/src/_sqlite3.rs index 971c4ec13ac..0a889c4d1e3 100644 --- a/crates/stdlib/src/_sqlite3.rs +++ b/crates/stdlib/src/_sqlite3.rs @@ -61,8 +61,8 @@ mod _sqlite3 { }, convert::IntoObject, function::{ - ArgCallable, ArgIterable, Either, FsPath, FuncArgs, OptionalArg, PyComparisonValue, - PySetterValue, + 
ArgCallable, ArgIterable, FsPath, FuncArgs, OptionalArg, PyComparisonValue, + PySetterValue, TimeoutSeconds, }, object::{Traverse, TraverseFn}, protocol::{ @@ -333,8 +333,8 @@ mod _sqlite3 { struct ConnectArgs { #[pyarg(any)] database: FsPath, - #[pyarg(any, default = Either::A(5.0))] - timeout: Either, + #[pyarg(any, default = TimeoutSeconds::new(5.0))] + timeout: TimeoutSeconds, #[pyarg(any, default = 0)] detect_types: c_int, #[pyarg(any, default = Some(vm.ctx.empty_str.to_owned()))] @@ -976,7 +976,7 @@ mod _sqlite3 { } } - #[pyclass(with(Constructor, Callable, Initializer), flags(BASETYPE))] + #[pyclass(with(Constructor, Callable, Initializer), flags(BASETYPE, HAS_WEAKREF))] impl Connection { fn drop_db(&self) { self.db.lock().take(); @@ -991,10 +991,7 @@ mod _sqlite3 { fn initialize_db(args: &ConnectArgs, vm: &VirtualMachine) -> PyResult { let path = args.database.to_cstring(vm)?; let db = Sqlite::from(SqliteRaw::open(path.as_ptr(), args.uri, vm)?); - let timeout = (match args.timeout { - Either::A(float) => float, - Either::B(int) => int as f64, - } * 1000.0) as c_int; + let timeout = (args.timeout.to_secs_f64() * 1000.0) as c_int; db.busy_timeout(timeout); if let Some(isolation_level) = &args.isolation_level { begin_statement_ptr_from_isolation_level(isolation_level, vm)?; @@ -1512,9 +1509,9 @@ mod _sqlite3 { let _ = unsafe { self.isolation_level.swap(value) }; Ok(()) } - PySetterValue::Delete => Err(vm.new_attribute_error( - "'isolation_level' attribute cannot be deleted".to_owned(), - )), + PySetterValue::Delete => { + Err(vm.new_attribute_error("'isolation_level' attribute cannot be deleted")) + } } } @@ -1629,7 +1626,10 @@ mod _sqlite3 { size: Option, } - #[pyclass(with(Constructor, Initializer, IterNext, Iterable), flags(BASETYPE))] + #[pyclass( + with(Constructor, Initializer, IterNext, Iterable), + flags(BASETYPE, HAS_WEAKREF) + )] impl Cursor { fn new( connection: PyRef, diff --git a/crates/stdlib/src/_tokenize.rs b/crates/stdlib/src/_tokenize.rs new 
file mode 100644 index 00000000000..13e40ff12b0 --- /dev/null +++ b/crates/stdlib/src/_tokenize.rs @@ -0,0 +1,747 @@ +pub(crate) use _tokenize::module_def; + +#[pymodule] +mod _tokenize { + use crate::{ + common::lock::PyRwLock, + vm::{ + AsObject, Py, PyObjectRef, PyPayload, PyResult, VirtualMachine, + builtins::{PyBytes, PyStr, PyType}, + convert::ToPyObject, + function::ArgCallable, + protocol::PyIterReturn, + types::{Constructor, IterNext, Iterable, SelfIter}, + }, + }; + use ruff_python_ast::PySourceType; + use ruff_python_ast::token::{Token, TokenKind}; + use ruff_python_parser::{ + LexicalErrorType, ParseError, ParseErrorType, parse_unchecked_source, + }; + use ruff_source_file::{LineIndex, LineRanges}; + use ruff_text_size::{Ranged, TextSize}; + use core::fmt; + + const TOKEN_ENDMARKER: u8 = 0; + const TOKEN_DEDENT: u8 = 6; + const TOKEN_OP: u8 = 55; + const TOKEN_COMMENT: u8 = 65; + const TOKEN_NL: u8 = 66; + + #[pyattr] + #[pyclass(name = "TokenizerIter")] + #[derive(PyPayload)] + pub struct PyTokenizerIter { + readline: ArgCallable, + extra_tokens: bool, + encoding: Option, + state: PyRwLock, + } + + impl PyTokenizerIter { + fn readline(&self, vm: &VirtualMachine) -> PyResult { + let raw_line = match self.readline.invoke((), vm) { + Ok(v) => v, + Err(err) => { + if err.fast_isinstance(vm.ctx.exceptions.stop_iteration) { + return Ok(String::new()); + } + return Err(err); + } + }; + Ok(match &self.encoding { + Some(encoding) => { + let bytes = raw_line + .downcast::() + .map_err(|_| vm.new_type_error("readline() returned a non-bytes object"))?; + vm.state + .codec_registry + .decode_text(bytes.into(), encoding, None, vm) + .map(|s| s.to_string())? 
+ } + None => raw_line + .downcast::() + .map(|s| s.to_string()) + .map_err(|_| vm.new_type_error("readline() returned a non-string object"))?, + }) + } + } + + impl fmt::Debug for PyTokenizerIter { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("PyTokenizerIter") + .field("extra_tokens", &self.extra_tokens) + .field("encoding", &self.encoding) + .finish() + } + } + + #[pyclass(with(Constructor, Iterable, IterNext))] + impl PyTokenizerIter {} + + impl Constructor for PyTokenizerIter { + type Args = PyTokenizerIterArgs; + + fn py_new(_cls: &Py, args: Self::Args, _vm: &VirtualMachine) -> PyResult { + let Self::Args { + readline, + extra_tokens, + encoding, + } = args; + + Ok(Self { + readline, + extra_tokens, + encoding: encoding.map(|s| s.to_string()), + state: PyRwLock::new(TokenizerState { + phase: TokenizerPhase::Reading { + source: String::new(), + }, + }), + }) + } + } + + impl SelfIter for PyTokenizerIter {} + + impl IterNext for PyTokenizerIter { + fn next(zelf: &Py, vm: &VirtualMachine) -> PyResult { + let mut state = zelf.state.read().clone(); + + loop { + match &mut state.phase { + TokenizerPhase::Reading { source } => { + let line = zelf.readline(vm)?; + if line.is_empty() { + let accumulated = core::mem::take(source); + let parsed = + parse_unchecked_source(&accumulated, PySourceType::Python); + let tokens: Vec = parsed.tokens().iter().copied().collect(); + let errors: Vec = parsed.errors().to_vec(); + let line_index = LineIndex::from_source_text(&accumulated); + let implicit_nl = !accumulated.ends_with('\n'); + state.phase = TokenizerPhase::Yielding { + source: accumulated, + tokens, + errors, + index: 0, + line_index, + need_implicit_nl: implicit_nl, + pending_fstring_parts: Vec::new(), + pending_empty_fstring_middle: None, + }; + } else { + source.push_str(&line); + } + } + TokenizerPhase::Yielding { .. 
} => { + let result = + emit_next_token(&mut state, zelf.extra_tokens, vm)?; + *zelf.state.write() = state; + return Ok(result); + } + TokenizerPhase::Done => { + return Ok(PyIterReturn::StopIteration(None)); + } + } + } + } + } + + /// Emit the next token from the Yielding phase. + fn emit_next_token( + state: &mut TokenizerState, + extra_tokens: bool, + vm: &VirtualMachine, + ) -> PyResult { + let TokenizerPhase::Yielding { + source, + tokens, + errors, + index, + line_index, + need_implicit_nl, + pending_fstring_parts, + pending_empty_fstring_middle, + } = &mut state.phase + else { + unreachable!() + }; + + // Emit pending empty FSTRING_MIDDLE (for format spec nesting) + if let Some((mid_type, mid_line, mid_col, mid_line_str)) = + pending_empty_fstring_middle.take() + { + return Ok(PyIterReturn::Return(make_token_tuple( + vm, + mid_type, + "", + mid_line, + mid_col as isize, + mid_line, + mid_col as isize, + &mid_line_str, + ))); + } + + // Emit any pending fstring sub-tokens first + if let Some((tok_type, tok_str, sl, sc, el, ec)) = pending_fstring_parts.pop() { + let offset: usize = source + .lines() + .take(sl.saturating_sub(1)) + .map(|l| l.len() + 1) + .sum(); + let full_line = + source.full_line_str(TextSize::from(offset.min(source.len()) as u32)); + return Ok(PyIterReturn::Return(make_token_tuple( + vm, tok_type, &tok_str, sl, sc as isize, el, ec as isize, full_line, + ))); + } + + let source_len = TextSize::from(source.len() as u32); + + while *index < tokens.len() { + let token = tokens[*index]; + *index += 1; + let kind = token.kind(); + let range = token.range(); + + // Check for lexical indentation errors. + // Skip when source has tabs — ruff and CPython handle tab + // indentation differently (CPython uses tabsize=8), so ruff may + // report false IndentationErrors for valid mixed-tab code. 
+ if !source.contains('\t') { + for err in errors.iter() { + if !matches!( + err.error, + ParseErrorType::Lexical(LexicalErrorType::IndentationError) + ) { + continue; + } + if err.location.start() <= range.start() + && range.start() < err.location.end() + { + return Err(raise_indentation_error(vm, err, source, line_index)); + } + } + } + + if kind == TokenKind::EndOfFile { + continue; + } + + if !extra_tokens + && matches!(kind, TokenKind::Comment | TokenKind::NonLogicalNewline) + { + continue; + } + + let raw_type = token_kind_value(kind); + let token_type = if extra_tokens && raw_type > TOKEN_DEDENT && raw_type < TOKEN_OP + { + TOKEN_OP + } else { + raw_type + }; + + let (token_str, start_line, start_col, end_line, end_col, line_str) = + if kind == TokenKind::Dedent { + let last_line = source.lines().count(); + let default_pos = if extra_tokens { + (last_line + 1, 0) + } else { + (last_line, 0) + }; + let (pos, dedent_line) = + next_non_dedent_info(tokens, *index, source, line_index, default_pos); + ("", pos.0, pos.1, pos.0, pos.1, dedent_line) + } else { + let start_lc = line_index.line_column(range.start(), source); + let start_line = start_lc.line.get(); + let start_col = start_lc.column.to_zero_indexed(); + let implicit_newline = range.start() >= source_len; + let in_source = range.end() <= source_len; + + let (s, el, ec) = if kind == TokenKind::Newline { + if extra_tokens { + if implicit_newline { + ("", start_line, start_col + 1) + } else { + let s = if source[range].starts_with('\r') { + "\r\n" + } else { + "\n" + }; + (s, start_line, start_col + s.len()) + } + } else { + ("", start_line, start_col) + } + } else if kind == TokenKind::NonLogicalNewline { + let s = if in_source { &source[range] } else { "" }; + (s, start_line, start_col + s.len()) + } else { + let end_lc = line_index.line_column(range.end(), source); + let s = if in_source { &source[range] } else { "" }; + (s, end_lc.line.get(), end_lc.column.to_zero_indexed()) + }; + let line_str = 
source.full_line_str(range.start()); + (s, start_line, start_col, el, ec, line_str) + }; + + // Handle FSTRING_MIDDLE/TSTRING_MIDDLE brace unescaping + if matches!(kind, TokenKind::FStringMiddle | TokenKind::TStringMiddle) + && (token_str.contains("{{") || token_str.contains("}}")) + { + let mut parts = + split_fstring_middle(token_str, token_type, start_line, start_col) + .into_iter(); + let (tt, ts, sl, sc, el, ec) = parts.next().unwrap(); + let rest: Vec<_> = parts.collect(); + for p in rest.into_iter().rev() { + pending_fstring_parts.push(p); + } + return Ok(PyIterReturn::Return(make_token_tuple( + vm, tt, &ts, sl, sc as isize, el, ec as isize, line_str, + ))); + } + + // After emitting a Rbrace inside an fstring, check if the + // next token is also Rbrace without an intervening FStringMiddle. + // CPython emits an empty FSTRING_MIDDLE in that position. + if kind == TokenKind::Rbrace + && tokens + .get(*index) + .is_some_and(|t| t.kind() == TokenKind::Rbrace) + { + let mid_type = find_fstring_middle_type(tokens, *index); + *pending_empty_fstring_middle = Some(( + mid_type, + end_line, + end_col, + line_str.to_string(), + )); + } + + return Ok(PyIterReturn::Return(make_token_tuple( + vm, token_type, token_str, start_line, start_col as isize, end_line, + end_col as isize, line_str, + ))); + } + + // Emit implicit NL before ENDMARKER if source + // doesn't end with newline and last token is Comment + if extra_tokens && core::mem::take(need_implicit_nl) { + let last_tok = tokens + .iter() + .rev() + .find(|t| t.kind() != TokenKind::EndOfFile); + if let Some(last) = last_tok.filter(|t| t.kind() == TokenKind::Comment) { + let end_lc = line_index.line_column(last.range().end(), source); + let nl_line = end_lc.line.get(); + let nl_col = end_lc.column.to_zero_indexed(); + return Ok(PyIterReturn::Return(make_token_tuple( + vm, + TOKEN_NL, + "", + nl_line, + nl_col as isize, + nl_line, + nl_col as isize + 1, + source.full_line_str(last.range().start()), + ))); + } + } + 
+ // Check for unclosed brackets before ENDMARKER — CPython's tokenizer + // raises SyntaxError("EOF in multi-line statement") in this case. + { + let bracket_count: i32 = tokens + .iter() + .map(|t| match t.kind() { + TokenKind::Lpar | TokenKind::Lsqb | TokenKind::Lbrace => 1, + TokenKind::Rpar | TokenKind::Rsqb | TokenKind::Rbrace => -1, + _ => 0, + }) + .sum(); + if bracket_count > 0 { + let last_line = source.lines().count(); + return Err(raise_syntax_error( + vm, + "EOF in multi-line statement", + last_line + 1, + 0, + )); + } + } + + // All tokens consumed — emit ENDMARKER + let last_line = source.lines().count(); + let (em_line, em_col, em_line_str): (usize, isize, &str) = if extra_tokens { + (last_line + 1, 0, "") + } else { + let last_line_text = source.full_line_str(TextSize::from( + source.len().saturating_sub(1) as u32, + )); + (last_line, -1, last_line_text) + }; + + let result = make_token_tuple( + vm, TOKEN_ENDMARKER, "", em_line, em_col, em_line, em_col, em_line_str, + ); + state.phase = TokenizerPhase::Done; + Ok(PyIterReturn::Return(result)) + } + + /// Determine whether to emit FSTRING_MIDDLE (60) or TSTRING_MIDDLE (63) + /// by looking back for the most recent FStringStart/TStringStart. + fn find_fstring_middle_type(tokens: &[Token], index: usize) -> u8 { + let mut depth = 0i32; + for i in (0..index).rev() { + match tokens[i].kind() { + TokenKind::FStringEnd | TokenKind::TStringEnd => depth += 1, + TokenKind::FStringStart => { + if depth == 0 { + return 60; // FSTRING_MIDDLE + } + depth -= 1; + } + TokenKind::TStringStart => { + if depth == 0 { + return 63; // TSTRING_MIDDLE + } + depth -= 1; + } + _ => {} + } + } + 60 // default to FSTRING_MIDDLE + } + + /// Find the next non-DEDENT token's position and source line. + /// Returns ((line, col), line_str). 
+ fn next_non_dedent_info<'a>( + tokens: &[Token], + index: usize, + source: &'a str, + line_index: &LineIndex, + default_pos: (usize, usize), + ) -> ((usize, usize), &'a str) { + for future in &tokens[index..] { + match future.kind() { + TokenKind::Dedent => continue, + TokenKind::EndOfFile => return (default_pos, ""), + _ => { + let flc = line_index.line_column(future.range().start(), source); + let pos = (flc.line.get(), flc.column.to_zero_indexed()); + return (pos, source.full_line_str(future.range().start())); + } + } + } + (default_pos, "") + } + + /// Raise a SyntaxError with the given message and position. + fn raise_syntax_error( + vm: &VirtualMachine, + msg: &str, + lineno: usize, + offset: usize, + ) -> rustpython_vm::builtins::PyBaseExceptionRef { + let exc = vm.new_exception_msg( + vm.ctx.exceptions.syntax_error.to_owned(), + msg.into(), + ); + let obj = exc.as_object(); + let _ = obj.set_attr("msg", vm.ctx.new_str(msg), vm); + let _ = obj.set_attr("lineno", vm.ctx.new_int(lineno), vm); + let _ = obj.set_attr("offset", vm.ctx.new_int(offset), vm); + let _ = obj.set_attr("filename", vm.ctx.new_str(""), vm); + let _ = obj.set_attr("text", vm.ctx.none(), vm); + exc + } + + /// Raise an IndentationError from a parse error. 
+ fn raise_indentation_error( + vm: &VirtualMachine, + err: &ParseError, + source: &str, + line_index: &LineIndex, + ) -> rustpython_vm::builtins::PyBaseExceptionRef { + let err_lc = line_index.line_column(err.location.start(), source); + let err_line_text = source.full_line_str(err.location.start()); + let err_text = err_line_text.trim_end_matches('\n').trim_end_matches('\r'); + let msg = format!("{}", err.error); + let exc = vm.new_exception_msg( + vm.ctx.exceptions.indentation_error.to_owned(), + msg.clone().into(), + ); + let obj = exc.as_object(); + let _ = obj.set_attr("lineno", vm.ctx.new_int(err_lc.line.get()), vm); + let _ = obj.set_attr("offset", vm.ctx.new_int(err_text.len() as i64 + 1), vm); + let _ = obj.set_attr("msg", vm.ctx.new_str(msg), vm); + let _ = obj.set_attr("filename", vm.ctx.new_str(""), vm); + let _ = obj.set_attr("text", vm.ctx.new_str(err_text), vm); + exc + } + + /// Split an FSTRING_MIDDLE/TSTRING_MIDDLE token containing `{{`/`}}` + /// into multiple unescaped sub-tokens. + /// Returns vec of (type, string, start_line, start_col, end_line, end_col). 
+ fn split_fstring_middle( + raw: &str, + token_type: u8, + start_line: usize, + start_col: usize, + ) -> Vec<(u8, String, usize, usize, usize, usize)> { + let mut parts = Vec::new(); + let mut current = String::new(); + // Track source position (line, col) — these correspond to the + // original source positions (with {{ and }} still doubled) + let mut cur_line = start_line; + let mut cur_col = start_col; + // Track the start position of the current accumulating part + let mut part_start_line = cur_line; + let mut part_start_col = cur_col; + let mut chars = raw.chars().peekable(); + + // Compute end position of the current accumulated text + let end_pos = |current: &str, start_line: usize, start_col: usize| -> (usize, usize) { + let mut el = start_line; + let mut ec = start_col; + for ch in current.chars() { + if ch == '\n' { + el += 1; + ec = 0; + } else { + ec += ch.len_utf8(); + } + } + (el, ec) + }; + + while let Some(ch) = chars.next() { + if ch == '{' && chars.peek() == Some(&'{') { + chars.next(); + current.push('{'); + cur_col += 2; // skip both {{ in source + } else if ch == '}' && chars.peek() == Some(&'}') { + chars.next(); + // Flush accumulated text before }} + if !current.is_empty() { + let (el, ec) = end_pos(&current, part_start_line, part_start_col); + parts.push(( + token_type, + core::mem::take(&mut current), + part_start_line, + part_start_col, + el, + ec, + )); + } + // Emit unescaped '}' at source position of }} + parts.push(( + token_type, + "}".to_string(), + cur_line, + cur_col, + cur_line, + cur_col + 1, + )); + cur_col += 2; // skip both }} in source + part_start_line = cur_line; + part_start_col = cur_col; + } else { + if current.is_empty() { + part_start_line = cur_line; + part_start_col = cur_col; + } + current.push(ch); + if ch == '\n' { + cur_line += 1; + cur_col = 0; + } else { + cur_col += ch.len_utf8(); + } + } + } + + if !current.is_empty() { + let (el, ec) = end_pos(&current, part_start_line, part_start_col); + parts.push((token_type, 
current, part_start_line, part_start_col, el, ec)); + } + + parts + } + + #[allow(clippy::too_many_arguments)] + fn make_token_tuple( + vm: &VirtualMachine, + token_type: u8, + string: &str, + start_line: usize, + start_col: isize, + end_line: usize, + end_col: isize, + line: &str, + ) -> PyObjectRef { + vm.ctx + .new_tuple(vec![ + token_type.to_pyobject(vm), + vm.ctx.new_str(string).into(), + vm.ctx + .new_tuple(vec![start_line.to_pyobject(vm), start_col.to_pyobject(vm)]) + .into(), + vm.ctx + .new_tuple(vec![end_line.to_pyobject(vm), end_col.to_pyobject(vm)]) + .into(), + vm.ctx.new_str(line).into(), + ]) + .into() + } + + #[derive(FromArgs)] + pub struct PyTokenizerIterArgs { + #[pyarg(positional)] + readline: ArgCallable, + #[pyarg(named)] + extra_tokens: bool, + #[pyarg(named, optional)] + encoding: Option>, + } + + #[derive(Clone, Debug)] + struct TokenizerState { + phase: TokenizerPhase, + } + + #[derive(Clone, Debug)] + enum TokenizerPhase { + Reading { + source: String, + }, + Yielding { + source: String, + tokens: Vec, + errors: Vec, + index: usize, + line_index: LineIndex, + need_implicit_nl: bool, + /// Pending sub-tokens from FSTRING_MIDDLE splitting + pending_fstring_parts: Vec<(u8, String, usize, usize, usize, usize)>, + /// Pending empty FSTRING_MIDDLE for format spec nesting: + /// (type, line, col, line_str) + pending_empty_fstring_middle: Option<(u8, usize, usize, String)>, + }, + Done, + } + + const fn token_kind_value(kind: TokenKind) -> u8 { + match kind { + TokenKind::EndOfFile => 0, + TokenKind::Name + | TokenKind::For + | TokenKind::In + | TokenKind::Pass + | TokenKind::Class + | TokenKind::And + | TokenKind::Is + | TokenKind::Raise + | TokenKind::True + | TokenKind::False + | TokenKind::Assert + | TokenKind::Try + | TokenKind::While + | TokenKind::Yield + | TokenKind::Lambda + | TokenKind::None + | TokenKind::Not + | TokenKind::Or + | TokenKind::Break + | TokenKind::Continue + | TokenKind::Global + | TokenKind::Nonlocal + | 
TokenKind::Return + | TokenKind::Except + | TokenKind::Import + | TokenKind::Case + | TokenKind::Match + | TokenKind::Type + | TokenKind::Await + | TokenKind::With + | TokenKind::Del + | TokenKind::Finally + | TokenKind::From + | TokenKind::Def + | TokenKind::If + | TokenKind::Else + | TokenKind::Elif + | TokenKind::As + | TokenKind::Async => 1, + TokenKind::Int | TokenKind::Complex | TokenKind::Float => 2, + TokenKind::String => 3, + TokenKind::Newline => 4, + TokenKind::NonLogicalNewline => TOKEN_NL, + TokenKind::Indent => 5, + TokenKind::Dedent => 6, + TokenKind::Lpar => 7, + TokenKind::Rpar => 8, + TokenKind::Lsqb => 9, + TokenKind::Rsqb => 10, + TokenKind::Colon => 11, + TokenKind::Comma => 12, + TokenKind::Semi => 13, + TokenKind::Plus => 14, + TokenKind::Minus => 15, + TokenKind::Star => 16, + TokenKind::Slash => 17, + TokenKind::Vbar => 18, + TokenKind::Amper => 19, + TokenKind::Less => 20, + TokenKind::Greater => 21, + TokenKind::Equal => 22, + TokenKind::Dot => 23, + TokenKind::Percent => 24, + TokenKind::Lbrace => 25, + TokenKind::Rbrace => 26, + TokenKind::EqEqual => 27, + TokenKind::NotEqual => 28, + TokenKind::LessEqual => 29, + TokenKind::GreaterEqual => 30, + TokenKind::Tilde => 31, + TokenKind::CircumFlex => 32, + TokenKind::LeftShift => 33, + TokenKind::RightShift => 34, + TokenKind::DoubleStar => 35, + TokenKind::PlusEqual => 36, + TokenKind::MinusEqual => 37, + TokenKind::StarEqual => 38, + TokenKind::SlashEqual => 39, + TokenKind::PercentEqual => 40, + TokenKind::AmperEqual => 41, + TokenKind::VbarEqual => 42, + TokenKind::CircumflexEqual => 43, + TokenKind::LeftShiftEqual => 44, + TokenKind::RightShiftEqual => 45, + TokenKind::DoubleStarEqual => 46, + TokenKind::DoubleSlash => 47, + TokenKind::DoubleSlashEqual => 48, + TokenKind::At => 49, + TokenKind::AtEqual => 50, + TokenKind::Rarrow => 51, + TokenKind::Ellipsis => 52, + TokenKind::ColonEqual => 53, + TokenKind::Exclamation => 54, + TokenKind::FStringStart => 59, + TokenKind::FStringMiddle 
=> 60, + TokenKind::FStringEnd => 61, + TokenKind::Comment => TOKEN_COMMENT, + TokenKind::TStringStart => 62, + TokenKind::TStringMiddle => 63, + TokenKind::TStringEnd => 64, + TokenKind::IpyEscapeCommand + | TokenKind::Question + | TokenKind::Unknown => 67, // ERRORTOKEN + } + } +} diff --git a/crates/stdlib/src/array.rs b/crates/stdlib/src/array.rs index 656b5028623..15a64c6d99c 100644 --- a/crates/stdlib/src/array.rs +++ b/crates/stdlib/src/array.rs @@ -35,7 +35,7 @@ mod array { SaturatedSlice, SequenceIndex, SequenceIndexOp, SliceableSequenceMutOp, SliceableSequenceOp, }, - stdlib::warnings, + stdlib::_warnings, types::{ AsBuffer, AsMapping, AsSequence, Comparable, Constructor, IterNext, Iterable, PyComparisonOp, Representable, SelfIter, @@ -647,7 +647,7 @@ mod array { } if spec == 'u' { - warnings::warn( + _warnings::warn( vm.ctx.exceptions.deprecation_warning, "The 'u' type code is deprecated and will be removed in Python 3.16".to_owned(), 1, @@ -698,7 +698,7 @@ mod array { } #[pyclass( - flags(BASETYPE), + flags(BASETYPE, HAS_WEAKREF), with( Comparable, AsBuffer, diff --git a/crates/stdlib/src/faulthandler.rs b/crates/stdlib/src/faulthandler.rs index f618f8f6731..9c4373c312e 100644 --- a/crates/stdlib/src/faulthandler.rs +++ b/crates/stdlib/src/faulthandler.rs @@ -405,7 +405,7 @@ mod decl { // Get all threads' frame stacks from the shared registry #[cfg(feature = "threading")] { - let current_tid = rustpython_vm::stdlib::thread::get_ident(); + let current_tid = rustpython_vm::stdlib::_thread::get_ident(); let registry = vm.state.thread_frames.lock(); // First dump non-current threads, then current thread last @@ -463,7 +463,7 @@ mod decl { // Install signal handlers if !faulthandler_enable_internal() { - return Err(vm.new_runtime_error("Failed to enable faulthandler".to_owned())); + return Err(vm.new_runtime_error("Failed to enable faulthandler")); } Ok(()) @@ -802,9 +802,7 @@ mod decl { // Check if it's an integer (file descriptor) if let Ok(fd) = 
f.try_to_value::(vm) { if fd < 0 { - return Err( - vm.new_value_error("file is not a valid file descriptor".to_owned()) - ); + return Err(vm.new_value_error("file is not a valid file descriptor")); } return Ok(fd); } @@ -812,9 +810,7 @@ mod decl { let fileno = vm.call_method(&f, "fileno", ())?; let fd: i32 = fileno.try_to_value(vm)?; if fd < 0 { - return Err( - vm.new_value_error("file is not a valid file descriptor".to_owned()) - ); + return Err(vm.new_value_error("file is not a valid file descriptor")); } // Try to flush the file let _ = vm.call_method(&f, "flush", ()); @@ -824,7 +820,7 @@ mod decl { // file=None or file not passed: fall back to sys.stderr let stderr = vm.sys_module.get_attr("stderr", vm)?; if vm.is_none(&stderr) { - return Err(vm.new_runtime_error("sys.stderr is None".to_owned())); + return Err(vm.new_runtime_error("sys.stderr is None")); } let fileno = vm.call_method(&stderr, "fileno", ())?; let fd: i32 = fileno.try_to_value(vm)?; @@ -912,7 +908,7 @@ mod decl { let timeout: f64 = args.timeout.into_float(); if timeout <= 0.0 { - return Err(vm.new_value_error("timeout must be greater than 0".to_owned())); + return Err(vm.new_value_error("timeout must be greater than 0")); } let fd = get_fd_from_file_opt(args.file, vm)?; @@ -920,7 +916,7 @@ mod decl { // Convert timeout to microseconds let timeout_us = (timeout * 1_000_000.0) as u64; if timeout_us == 0 { - return Err(vm.new_value_error("timeout must be greater than 0".to_owned())); + return Err(vm.new_value_error("timeout must be greater than 0")); } let header = format_timeout(timeout_us); @@ -1098,7 +1094,7 @@ mod decl { // Check if signal is in valid range if !(1..64).contains(&signum) { - return Err(vm.new_value_error("signal number out of range".to_owned())); + return Err(vm.new_value_error("signal number out of range")); } Ok(()) diff --git a/crates/stdlib/src/fcntl.rs b/crates/stdlib/src/fcntl.rs index 407a2dfd6b3..0f75a09ba0f 100644 --- a/crates/stdlib/src/fcntl.rs +++ 
b/crates/stdlib/src/fcntl.rs @@ -8,7 +8,7 @@ mod fcntl { PyResult, VirtualMachine, builtins::PyIntRef, function::{ArgMemoryBuffer, ArgStrOrBytesLike, Either, OptionalArg}, - stdlib::io, + stdlib::_io, }; // TODO: supply these from (please file an issue/PR upstream): @@ -57,7 +57,7 @@ mod fcntl { #[pyfunction] fn fcntl( - io::Fildes(fd): io::Fildes, + _io::Fildes(fd): _io::Fildes, cmd: i32, arg: OptionalArg>, vm: &VirtualMachine, @@ -91,7 +91,7 @@ mod fcntl { #[pyfunction] fn ioctl( - io::Fildes(fd): io::Fildes, + _io::Fildes(fd): _io::Fildes, request: i64, arg: OptionalArg, i32>>, mutate_flag: OptionalArg, @@ -149,7 +149,7 @@ mod fcntl { // XXX: at the time of writing, wasi and redox don't have the necessary constants/function #[cfg(not(any(target_os = "wasi", target_os = "redox")))] #[pyfunction] - fn flock(io::Fildes(fd): io::Fildes, operation: i32, vm: &VirtualMachine) -> PyResult { + fn flock(_io::Fildes(fd): _io::Fildes, operation: i32, vm: &VirtualMachine) -> PyResult { let ret = unsafe { libc::flock(fd, operation) }; // TODO: add support for platforms that don't have a builtin `flock` syscall if ret < 0 { @@ -162,7 +162,7 @@ mod fcntl { #[cfg(not(any(target_os = "wasi", target_os = "redox")))] #[pyfunction] fn lockf( - io::Fildes(fd): io::Fildes, + _io::Fildes(fd): _io::Fildes, cmd: i32, len: OptionalArg, start: OptionalArg, diff --git a/crates/stdlib/src/hashlib.rs b/crates/stdlib/src/hashlib.rs index 924009884f8..584ed1714d5 100644 --- a/crates/stdlib/src/hashlib.rs +++ b/crates/stdlib/src/hashlib.rs @@ -210,8 +210,7 @@ pub mod _hashlib { (Some(_), Some(_)) => Err(vm.new_type_error( "'data' and 'string' are mutually exclusive \ and support for 'string' keyword parameter \ - is slated for removal in a future version." 
- .to_owned(), + is slated for removal in a future version.", )), } } @@ -306,7 +305,7 @@ pub mod _hashlib { impl PyHmac { #[pyslot] fn slot_new(_cls: PyTypeRef, _args: FuncArgs, vm: &VirtualMachine) -> PyResult { - Err(vm.new_type_error("cannot create '_hashlib.HMAC' instances".to_owned())) + Err(vm.new_type_error("cannot create '_hashlib.HMAC' instances")) } #[pygetset] @@ -758,9 +757,10 @@ pub mod _hashlib { #[pyfunction] fn hmac_new(args: NewHMACHashArgs, vm: &VirtualMachine) -> PyResult { - let digestmod = args.digestmod.into_option().ok_or_else(|| { - vm.new_type_error("Missing required parameter 'digestmod'.".to_owned()) - })?; + let digestmod = args + .digestmod + .into_option() + .ok_or_else(|| vm.new_type_error("Missing required parameter 'digestmod'."))?; let name = resolve_digestmod(&digestmod, vm)?; let key_buf = args.key.borrow_buf(); @@ -833,10 +833,10 @@ pub mod _hashlib { let name = args.hash_name.as_str().to_lowercase(); if args.iterations < 1 { - return Err(vm.new_value_error("iteration value must be greater than 0.".to_owned())); + return Err(vm.new_value_error("iteration value must be greater than 0.")); } let rounds = u32::try_from(args.iterations) - .map_err(|_| vm.new_overflow_error("iteration value is too great.".to_owned()))?; + .map_err(|_| vm.new_overflow_error("iteration value is too great."))?; let dklen: usize = match args.dklen.into_option() { Some(obj) if vm.is_none(&obj) => { @@ -845,10 +845,10 @@ pub mod _hashlib { Some(obj) => { let len: i64 = obj.try_into_value(vm)?; if len < 1 { - return Err(vm.new_value_error("key length must be greater than 0.".to_owned())); + return Err(vm.new_value_error("key length must be greater than 0.")); } usize::try_from(len) - .map_err(|_| vm.new_overflow_error("key length is too great.".to_owned()))? + .map_err(|_| vm.new_overflow_error("key length is too great."))? 
} None => hash_digest_size(&name).ok_or_else(|| unsupported_hash(&name, vm))?, }; diff --git a/crates/stdlib/src/lib.rs b/crates/stdlib/src/lib.rs index 8c234c22f89..4c06eea9ef4 100644 --- a/crates/stdlib/src/lib.rs +++ b/crates/stdlib/src/lib.rs @@ -49,6 +49,8 @@ mod pystruct; mod random; mod statistics; mod suggestions; +#[path = "_tokenize.rs"] +mod _tokenize; // TODO: maybe make this an extension module, if we ever get those // mod re; #[cfg(all(feature = "host_env", not(target_arch = "wasm32")))] @@ -225,6 +227,7 @@ pub fn stdlib_module_defs(ctx: &Context) -> Vec<&'static builtins::PyModuleDef> ssl::module_def(ctx), statistics::module_def(ctx), suggestions::module_def(ctx), + _tokenize::module_def(ctx), #[cfg(all(feature = "host_env", unix, not(target_os = "redox")))] syslog::module_def(ctx), #[cfg(all( diff --git a/crates/stdlib/src/math.rs b/crates/stdlib/src/math.rs index 80463dcaa22..b071ff5aad7 100644 --- a/crates/stdlib/src/math.rs +++ b/crates/stdlib/src/math.rs @@ -102,13 +102,13 @@ mod math { ))); } if b == 1.0 { - return Err(vm.new_value_error("math domain error".to_owned())); + return Err(vm.new_value_error("math domain error")); } } // Handle BigInt specially for large values (only for actual int type, not float) if let Some(i) = x.downcast_ref::() { return pymath::math::log_bigint(i.as_bigint(), base).map_err(|err| match err { - pymath::Error::EDOM => vm.new_value_error("expected a positive input".to_owned()), + pymath::Error::EDOM => vm.new_value_error("expected a positive input"), _ => pymath_exception(err, vm), }); } @@ -132,7 +132,7 @@ mod math { // Handle BigInt specially for large values (only for actual int type, not float) if let Some(i) = x.downcast_ref::() { return pymath::math::log2_bigint(i.as_bigint()).map_err(|err| match err { - pymath::Error::EDOM => vm.new_value_error("expected a positive input".to_owned()), + pymath::Error::EDOM => vm.new_value_error("expected a positive input"), _ => pymath_exception(err, vm), }); } @@ -151,7 
+151,7 @@ mod math { // Handle BigInt specially for large values (only for actual int type, not float) if let Some(i) = x.downcast_ref::() { return pymath::math::log10_bigint(i.as_bigint()).map_err(|err| match err { - pymath::Error::EDOM => vm.new_value_error("expected a positive input".to_owned()), + pymath::Error::EDOM => vm.new_value_error("expected a positive input"), _ => pymath_exception(err, vm), }); } diff --git a/crates/stdlib/src/mmap.rs b/crates/stdlib/src/mmap.rs index c9a6be3b392..d441a3dd887 100644 --- a/crates/stdlib/src/mmap.rs +++ b/crates/stdlib/src/mmap.rs @@ -557,13 +557,11 @@ mod mmap { // Parse tagname: None or a string let tag_str: Option = match tagname { Some(ref obj) if !vm.is_none(obj) => { - let s = obj.try_to_value::(vm).map_err(|_| { - vm.new_type_error("tagname must be a string or None".to_owned()) - })?; + let s = obj + .try_to_value::(vm) + .map_err(|_| vm.new_type_error("tagname must be a string or None"))?; if s.contains('\0') { - return Err(vm.new_value_error( - "tagname must not contain null characters".to_owned(), - )); + return Err(vm.new_value_error("tagname must not contain null characters")); } Some(s) } @@ -851,7 +849,7 @@ mod mmap { #[pyclass( with(Constructor, AsMapping, AsSequence, AsBuffer, Representable), - flags(BASETYPE) + flags(BASETYPE, HAS_WEAKREF) )] impl PyMmap { fn as_bytes_mut(&self) -> BorrowedValueMut<'_, [u8]> { diff --git a/crates/stdlib/src/multiprocessing.rs b/crates/stdlib/src/multiprocessing.rs index fe52cbd19fc..26d1bea8859 100644 --- a/crates/stdlib/src/multiprocessing.rs +++ b/crates/stdlib/src/multiprocessing.rs @@ -154,7 +154,7 @@ mod _multiprocessing { if timeout < 0.0 { 0 } else if timeout >= 0.5 * INFINITE as f64 { - return Err(vm.new_overflow_error("timeout is too large".to_owned())); + return Err(vm.new_overflow_error("timeout is too large")); } else { (timeout + 0.5) as u32 } @@ -236,9 +236,7 @@ mod _multiprocessing { if unsafe { ReleaseSemaphore(self.handle.as_raw(), 1, 
core::ptr::null_mut()) } == 0 { let err = unsafe { windows_sys::Win32::Foundation::GetLastError() }; if err == ERROR_TOO_MANY_POSTS { - return Err( - vm.new_value_error("semaphore or lock released too many times".to_owned()) - ); + return Err(vm.new_value_error("semaphore or lock released too many times")); } return Err(vm.new_last_os_error()); } @@ -294,7 +292,7 @@ mod _multiprocessing { #[pymethod] fn __reduce__(&self, vm: &VirtualMachine) -> PyResult { - Err(vm.new_type_error("cannot pickle 'SemLock' object".to_owned())) + Err(vm.new_type_error("cannot pickle 'SemLock' object")) } #[pymethod] @@ -338,13 +336,13 @@ mod _multiprocessing { fn py_new(_cls: &Py, args: Self::Args, vm: &VirtualMachine) -> PyResult { if args.kind != RECURSIVE_MUTEX && args.kind != SEMAPHORE { - return Err(vm.new_value_error("unrecognized kind".to_owned())); + return Err(vm.new_value_error("unrecognized kind")); } if args.maxvalue <= 0 { - return Err(vm.new_value_error("maxvalue must be positive".to_owned())); + return Err(vm.new_value_error("maxvalue must be positive")); } if args.value < 0 || args.value > args.maxvalue { - return Err(vm.new_value_error("invalid value".to_owned())); + return Err(vm.new_value_error("invalid value")); } let handle = SemHandle::create(args.value, args.maxvalue, vm)?; @@ -486,7 +484,7 @@ mod _multiprocessing { tv_sec: (delay / 1_000_000) as _, tv_usec: (delay % 1_000_000) as _, }; - unsafe { + vm.allow_threads(|| unsafe { libc::select( 0, core::ptr::null_mut(), @@ -494,7 +492,7 @@ mod _multiprocessing { core::ptr::null_mut(), &mut tv_delay, ) - }; + }); // check for signals - preserve the exception (e.g., KeyboardInterrupt) if let Err(exc) = vm.check_signals() { @@ -708,41 +706,69 @@ mod _multiprocessing { // if (res < 0 && errno == EAGAIN && blocking) if res < 0 && Errno::last() == Errno::EAGAIN && blocking { - // Couldn't acquire immediately, need to block + // Couldn't acquire immediately, need to block. 
+ // + // Save errno inside the allow_threads closure, before + // attach_thread() runs — matches CPython which saves + // `err = errno` before Py_END_ALLOW_THREADS. + #[cfg(not(target_vendor = "apple"))] { + let mut saved_errno; loop { + let sem_ptr = self.handle.as_ptr(); // Py_BEGIN_ALLOW_THREADS / Py_END_ALLOW_THREADS - // RustPython doesn't have GIL, so we just do the wait - if let Some(ref dl) = deadline { - res = unsafe { libc::sem_timedwait(self.handle.as_ptr(), dl) }; + let (r, e) = if let Some(ref dl) = deadline { + vm.allow_threads(|| { + let r = unsafe { libc::sem_timedwait(sem_ptr, dl) }; + ( + r, + if r < 0 { + Errno::last() + } else { + Errno::from_raw(0) + }, + ) + }) } else { - res = unsafe { libc::sem_wait(self.handle.as_ptr()) }; - } + vm.allow_threads(|| { + let r = unsafe { libc::sem_wait(sem_ptr) }; + ( + r, + if r < 0 { + Errno::last() + } else { + Errno::from_raw(0) + }, + ) + }) + }; + res = r; + saved_errno = e; if res >= 0 { break; } - let err = Errno::last(); - if err == Errno::EINTR { + if saved_errno == Errno::EINTR { vm.check_signals()?; continue; } break; } + if res < 0 { + return handle_wait_error(vm, saved_errno); + } } #[cfg(target_vendor = "apple")] { // macOS: use polled fallback since sem_timedwait is not available if let Some(ref dl) = deadline { match sem_timedwait_polled(self.handle.as_ptr(), dl, vm) { - Ok(()) => res = 0, + Ok(()) => {} Err(SemWaitError::Timeout) => { - // Timeout occurred - return false directly return Ok(false); } Err(SemWaitError::SignalException(exc)) => { - // Propagate the original exception (e.g., KeyboardInterrupt) return Err(exc); } Err(SemWaitError::OsError(e)) => { @@ -751,30 +777,42 @@ mod _multiprocessing { } } else { // No timeout: use sem_wait (available on macOS) + let mut saved_errno; loop { - res = unsafe { libc::sem_wait(self.handle.as_ptr()) }; + let sem_ptr = self.handle.as_ptr(); + let (r, e) = vm.allow_threads(|| { + let r = unsafe { libc::sem_wait(sem_ptr) }; + ( + r, + if r < 0 { + 
Errno::last() + } else { + Errno::from_raw(0) + }, + ) + }); + res = r; + saved_errno = e; if res >= 0 { break; } - let err = Errno::last(); - if err == Errno::EINTR { + if saved_errno == Errno::EINTR { vm.check_signals()?; continue; } break; } + if res < 0 { + return handle_wait_error(vm, saved_errno); + } } } - } - - // result handling: - if res < 0 { + } else if res < 0 { + // Non-blocking path failed, or blocking=false let err = Errno::last(); match err { Errno::EAGAIN | Errno::ETIMEDOUT => return Ok(false), Errno::EINTR => { - // EINTR should be handled by the check_signals() loop above - // If we reach here, check signals again and propagate any exception return vm.check_signals().map(|_| false); } _ => return Err(os_error(vm, err)), @@ -816,9 +854,7 @@ mod _multiprocessing { return Err(os_error(vm, Errno::last())); } if sval >= self.maxvalue { - return Err(vm.new_value_error( - "semaphore or lock released too many times".to_owned(), - )); + return Err(vm.new_value_error("semaphore or lock released too many times")); } } #[cfg(target_vendor = "apple")] @@ -837,9 +873,9 @@ mod _multiprocessing { if unsafe { libc::sem_post(self.handle.as_ptr()) } < 0 { return Err(os_error(vm, Errno::last())); } - return Err(vm.new_value_error( - "semaphore or lock released too many times".to_owned(), - )); + return Err( + vm.new_value_error("semaphore or lock released too many times") + ); } } } @@ -887,7 +923,7 @@ mod _multiprocessing { vm: &VirtualMachine, ) -> PyResult { let Some(ref name_str) = name else { - return Err(vm.new_value_error("cannot rebuild SemLock without name".to_owned())); + return Err(vm.new_value_error("cannot rebuild SemLock without name")); }; let handle = SemHandle::open_existing(name_str, vm)?; // return newsemlockobject(type, handle, kind, maxvalue, name_copy); @@ -915,7 +951,7 @@ mod _multiprocessing { /// Use multiprocessing.synchronize.SemLock wrapper which handles pickling. 
#[pymethod] fn __reduce__(&self, vm: &VirtualMachine) -> PyResult { - Err(vm.new_type_error("cannot pickle 'SemLock' object".to_owned())) + Err(vm.new_type_error("cannot pickle 'SemLock' object")) } /// Num of `acquire()`s minus num of `release()`s for this process. @@ -1012,11 +1048,11 @@ mod _multiprocessing { // _multiprocessing_SemLock_impl fn py_new(_cls: &Py, args: Self::Args, vm: &VirtualMachine) -> PyResult { if args.kind != RECURSIVE_MUTEX && args.kind != SEMAPHORE { - return Err(vm.new_value_error("unrecognized kind".to_owned())); + return Err(vm.new_value_error("unrecognized kind")); } // Value validation if args.value < 0 || args.value > args.maxvalue { - return Err(vm.new_value_error("invalid value".to_owned())); + return Err(vm.new_value_error("invalid value")); } let value = args.value as u32; @@ -1081,7 +1117,15 @@ mod _multiprocessing { full.push('/'); } full.push_str(name); - CString::new(full).map_err(|_| vm.new_value_error("embedded null character".to_owned())) + CString::new(full).map_err(|_| vm.new_value_error("embedded null character")) + } + + fn handle_wait_error(vm: &VirtualMachine, saved_errno: Errno) -> PyResult { + match saved_errno { + Errno::EAGAIN | Errno::ETIMEDOUT => Ok(false), + Errno::EINTR => vm.check_signals().map(|_| false), + _ => Err(os_error(vm, saved_errno)), + } } fn os_error(vm: &VirtualMachine, err: Errno) -> PyBaseExceptionRef { diff --git a/crates/stdlib/src/openssl.rs b/crates/stdlib/src/openssl.rs index ecea817b436..b00e9306eaf 100644 --- a/crates/stdlib/src/openssl.rs +++ b/crates/stdlib/src/openssl.rs @@ -394,11 +394,11 @@ mod _ssl { let nid = obj.nid(); let short_name = nid .short_name() - .map_err(|_| vm.new_value_error("NID has no short name".to_owned()))? + .map_err(|_| vm.new_value_error("NID has no short name"))? .to_owned(); let long_name = nid .long_name() - .map_err(|_| vm.new_value_error("NID has no long name".to_owned()))? + .map_err(|_| vm.new_value_error("NID has no long name"))? 
.to_owned(); Ok(( nid.as_raw(), @@ -1135,7 +1135,7 @@ mod _ssl { #[pygetset(setter)] fn set_options(&self, new_opts: i64, vm: &VirtualMachine) -> PyResult<()> { if new_opts < 0 { - return Err(vm.new_value_error("invalid options value".to_owned())); + return Err(vm.new_value_error("invalid options value")); } let new_opts = new_opts as libc::c_ulong; let mut ctx = self.builder(); @@ -1321,14 +1321,12 @@ mod _ssl { fn set_num_tickets(&self, value: isize, vm: &VirtualMachine) -> PyResult<()> { // Check for negative values if value < 0 { - return Err( - vm.new_value_error("num_tickets must be a non-negative integer".to_owned()) - ); + return Err(vm.new_value_error("num_tickets must be a non-negative integer")); } // Check that this is a server context if self.protocol != SslVersion::TlsServer { - return Err(vm.new_value_error("SSLContext is not a server context.".to_owned())); + return Err(vm.new_value_error("SSLContext is not a server context.")); } #[cfg(ossl110)] @@ -1421,7 +1419,7 @@ mod _ssl { } } else { if !callback.is_callable() { - return Err(vm.new_type_error("callback must be callable".to_owned())); + return Err(vm.new_type_error("callback must be callable")); } *self.psk_client_callback.lock() = Some(callback); // Note: The actual callback will be invoked via SSL app_data mechanism @@ -1457,7 +1455,7 @@ mod _ssl { } } else { if !callback.is_callable() { - return Err(vm.new_type_error("callback must be callable".to_owned())); + return Err(vm.new_type_error("callback must be callable")); } *self.psk_server_callback.lock() = Some(callback); if let OptionalArg::Present(hint) = identity_hint { @@ -1588,12 +1586,12 @@ mod _ssl { let store_ptr = unsafe { sys::SSL_CTX_get_cert_store(ctx.as_ptr()) }; if store_ptr.is_null() { - return Err(vm.new_memory_error("failed to get cert store".to_owned())); + return Err(vm.new_memory_error("failed to get cert store")); } let objs_ptr = unsafe { sys::X509_STORE_get0_objects(store_ptr) }; if objs_ptr.is_null() { - return 
Err(vm.new_memory_error("failed to query cert store".to_owned())); + return Err(vm.new_memory_error("failed to query cert store")); } let mut x509_count = 0; @@ -1727,9 +1725,7 @@ mod _ssl { ) -> PyResult<()> { // Check if this is a server context if self.protocol == SslVersion::TlsClient { - return Err(vm.new_value_error( - "sni_callback cannot be set on TLS_CLIENT context".to_owned(), - )); + return Err(vm.new_value_error("sni_callback cannot be set on TLS_CLIENT context")); } let mut callback_guard = self.sni_callback.lock(); @@ -1738,7 +1734,7 @@ mod _ssl { if !vm.is_none(&callback_obj) { // Check if callable if !callback_obj.is_callable() { - return Err(vm.new_type_error("not a callable object".to_owned())); + return Err(vm.new_type_error("not a callable object")); } // Set the callback @@ -1805,7 +1801,7 @@ mod _ssl { if !vm.is_none(&callback_obj) { // Check if callable if !callback_obj.is_callable() { - return Err(vm.new_type_error("not a callable object".to_owned())); + return Err(vm.new_type_error("not a callable object")); } // Set the callback @@ -2521,7 +2517,7 @@ mod _ssl { unsafe { let result = SSL_set_SSL_CTX(ssl_ptr, value.ctx().as_ptr()); if result.is_null() { - return Err(vm.new_runtime_error("Failed to set SSL context".to_owned())); + return Err(vm.new_runtime_error("Failed to set SSL context")); } } @@ -2806,7 +2802,7 @@ mod _ssl { #[cfg(not(ossl111))] { Err(vm.new_not_implemented_error( - "Post-handshake auth is not supported by your OpenSSL version.".to_owned(), + "Post-handshake auth is not supported by your OpenSSL version.", )) } } @@ -3116,32 +3112,26 @@ mod _ssl { // Check if value is SSLSession type let session = value .downcast_ref::() - .ok_or_else(|| vm.new_type_error("Value is not a SSLSession.".to_owned()))?; + .ok_or_else(|| vm.new_type_error("Value is not a SSLSession."))?; // Check if session refers to the same SSLContext if !std::ptr::eq( self.ctx.read().ctx.read().as_ptr(), session.ctx.ctx.read().as_ptr(), ) { - return Err( - 
vm.new_value_error("Session refers to a different SSLContext.".to_owned()) - ); + return Err(vm.new_value_error("Session refers to a different SSLContext.")); } // Check if this is a client socket if self.socket_type != SslServerOrClient::Client { - return Err( - vm.new_value_error("Cannot set session for server-side SSLSocket.".to_owned()) - ); + return Err(vm.new_value_error("Cannot set session for server-side SSLSocket.")); } // Check if handshake is not finished let stream = self.connection.read(); unsafe { if sys::SSL_is_init_finished(stream.ssl().as_ptr()) != 0 { - return Err( - vm.new_value_error("Cannot set session after handshake.".to_owned()) - ); + return Err(vm.new_value_error("Cannot set session after handshake.")); } let ret = sys::SSL_set_session(stream.ssl().as_ptr(), session.session); @@ -3182,7 +3172,7 @@ mod _ssl { } OptionalArg::Missing => { if n < 0 { - return Err(vm.new_value_error("size should not be negative".to_owned())); + return Err(vm.new_value_error("size should not be negative")); } n as usize } @@ -3602,7 +3592,7 @@ mod _ssl { unsafe { let bio = sys::BIO_new(sys::BIO_s_mem()); if bio.is_null() { - return Err(vm.new_memory_error("failed to allocate BIO".to_owned())); + return Err(vm.new_memory_error("failed to allocate BIO")); } sys::BIO_set_retry_read(bio); diff --git a/crates/stdlib/src/overlapped.rs b/crates/stdlib/src/overlapped.rs index 1243a1297ea..76d18cb7a9a 100644 --- a/crates/stdlib/src/overlapped.rs +++ b/crates/stdlib/src/overlapped.rs @@ -379,7 +379,7 @@ mod _overlapped { }; Ok((bytes.to_vec(), addr_len)) } - _ => Err(vm.new_value_error("illegal address_as_bytes argument".to_owned())), + _ => Err(vm.new_value_error("illegal address_as_bytes argument")), } } @@ -407,7 +407,7 @@ mod _overlapped { let scope_id = addr.Anonymous.sin6_scope_id; Ok((ip_str, port, flowinfo, scope_id).to_pyobject(vm)) } else { - Err(vm.new_value_error("recvfrom returned unsupported address family".to_owned())) + Err(vm.new_value_error("recvfrom 
returned unsupported address family")) } } } @@ -473,10 +473,10 @@ mod _overlapped { // Check operation state if matches!(inner.data, OverlappedData::None) { - return Err(vm.new_value_error("operation not yet attempted".to_owned())); + return Err(vm.new_value_error("operation not yet attempted")); } if matches!(inner.data, OverlappedData::NotStarted) { - return Err(vm.new_value_error("operation failed to start".to_owned())); + return Err(vm.new_value_error("operation failed to start")); } // Get the result @@ -573,7 +573,7 @@ mod _overlapped { let mut inner = zelf.inner.lock(); if !matches!(inner.data, OverlappedData::None) { - return Err(vm.new_value_error("operation already attempted".to_owned())); + return Err(vm.new_value_error("operation already attempted")); } #[cfg(target_pointer_width = "32")] @@ -630,19 +630,19 @@ mod _overlapped { let mut inner = zelf.inner.lock(); if !matches!(inner.data, OverlappedData::None) { - return Err(vm.new_value_error("operation already attempted".to_owned())); + return Err(vm.new_value_error("operation already attempted")); } inner.handle = handle as HANDLE; let buf_len = buf.desc.len; if buf_len > u32::MAX as usize { - return Err(vm.new_value_error("buffer too large".to_owned())); + return Err(vm.new_value_error("buffer too large")); } // For async read, buffer must be contiguous - we can't use a temporary copy // because Windows writes data directly to the buffer after this call returns let Some(contiguous) = buf.as_contiguous_mut() else { - return Err(vm.new_buffer_error("buffer is not contiguous".to_owned())); + return Err(vm.new_buffer_error("buffer is not contiguous")); }; inner.data = OverlappedData::ReadInto(buf.clone()); @@ -694,7 +694,7 @@ mod _overlapped { let mut inner = zelf.inner.lock(); if !matches!(inner.data, OverlappedData::None) { - return Err(vm.new_value_error("operation already attempted".to_owned())); + return Err(vm.new_value_error("operation already attempted")); } let mut flags = flags.unwrap_or(0); @@ 
-761,18 +761,18 @@ mod _overlapped { let mut inner = zelf.inner.lock(); if !matches!(inner.data, OverlappedData::None) { - return Err(vm.new_value_error("operation already attempted".to_owned())); + return Err(vm.new_value_error("operation already attempted")); } let mut flags = flags; inner.handle = handle as HANDLE; let buf_len = buf.desc.len; if buf_len > u32::MAX as usize { - return Err(vm.new_value_error("buffer too large".to_owned())); + return Err(vm.new_value_error("buffer too large")); } let Some(contiguous) = buf.as_contiguous_mut() else { - return Err(vm.new_buffer_error("buffer is not contiguous".to_owned())); + return Err(vm.new_buffer_error("buffer is not contiguous")); }; inner.data = OverlappedData::ReadInto(buf.clone()); @@ -828,19 +828,19 @@ mod _overlapped { let mut inner = zelf.inner.lock(); if !matches!(inner.data, OverlappedData::None) { - return Err(vm.new_value_error("operation already attempted".to_owned())); + return Err(vm.new_value_error("operation already attempted")); } inner.handle = handle as HANDLE; let buf_len = buf.desc.len; if buf_len > u32::MAX as usize { - return Err(vm.new_value_error("buffer too large".to_owned())); + return Err(vm.new_value_error("buffer too large")); } // For async write, buffer must be contiguous - we can't use a temporary copy // because Windows reads from the buffer after this call returns let Some(contiguous) = buf.as_contiguous() else { - return Err(vm.new_buffer_error("buffer is not contiguous".to_owned())); + return Err(vm.new_buffer_error("buffer is not contiguous")); }; inner.data = OverlappedData::Write(buf.clone()); @@ -886,17 +886,17 @@ mod _overlapped { let mut inner = zelf.inner.lock(); if !matches!(inner.data, OverlappedData::None) { - return Err(vm.new_value_error("operation already attempted".to_owned())); + return Err(vm.new_value_error("operation already attempted")); } inner.handle = handle as HANDLE; let buf_len = buf.desc.len; if buf_len > u32::MAX as usize { - return 
Err(vm.new_value_error("buffer too large".to_owned())); + return Err(vm.new_value_error("buffer too large")); } let Some(contiguous) = buf.as_contiguous() else { - return Err(vm.new_buffer_error("buffer is not contiguous".to_owned())); + return Err(vm.new_buffer_error("buffer is not contiguous")); }; inner.data = OverlappedData::Write(buf.clone()); @@ -948,7 +948,7 @@ mod _overlapped { let mut inner = zelf.inner.lock(); if !matches!(inner.data, OverlappedData::None) { - return Err(vm.new_value_error("operation already attempted".to_owned())); + return Err(vm.new_value_error("operation already attempted")); } // Buffer size: local address + remote address @@ -1016,7 +1016,7 @@ mod _overlapped { let mut inner = zelf.inner.lock(); if !matches!(inner.data, OverlappedData::None) { - return Err(vm.new_value_error("operation already attempted".to_owned())); + return Err(vm.new_value_error("operation already attempted")); } let (addr_bytes, addr_len) = parse_address(&address, vm)?; @@ -1085,7 +1085,7 @@ mod _overlapped { let mut inner = zelf.inner.lock(); if !matches!(inner.data, OverlappedData::None) { - return Err(vm.new_value_error("operation already attempted".to_owned())); + return Err(vm.new_value_error("operation already attempted")); } inner.handle = socket as HANDLE; @@ -1141,7 +1141,7 @@ mod _overlapped { let mut inner = zelf.inner.lock(); if !matches!(inner.data, OverlappedData::None) { - return Err(vm.new_value_error("operation already attempted".to_owned())); + return Err(vm.new_value_error("operation already attempted")); } inner.handle = socket as HANDLE; @@ -1200,7 +1200,7 @@ mod _overlapped { let mut inner = zelf.inner.lock(); if !matches!(inner.data, OverlappedData::None) { - return Err(vm.new_value_error("operation already attempted".to_owned())); + return Err(vm.new_value_error("operation already attempted")); } inner.handle = pipe as HANDLE; @@ -1243,7 +1243,7 @@ mod _overlapped { let mut inner = zelf.inner.lock(); if !matches!(inner.data, 
OverlappedData::None) { - return Err(vm.new_value_error("operation already attempted".to_owned())); + return Err(vm.new_value_error("operation already attempted")); } let (addr_bytes, addr_len) = parse_address(&address, vm)?; @@ -1251,11 +1251,11 @@ mod _overlapped { inner.handle = handle as HANDLE; let buf_len = buf.desc.len; if buf_len > u32::MAX as usize { - return Err(vm.new_value_error("buffer too large".to_owned())); + return Err(vm.new_value_error("buffer too large")); } let Some(contiguous) = buf.as_contiguous() else { - return Err(vm.new_buffer_error("buffer is not contiguous".to_owned())); + return Err(vm.new_buffer_error("buffer is not contiguous")); }; // Store both buffer and address in OverlappedData to keep them alive @@ -1322,7 +1322,7 @@ mod _overlapped { let mut inner = zelf.inner.lock(); if !matches!(inner.data, OverlappedData::None) { - return Err(vm.new_value_error("operation already attempted".to_owned())); + return Err(vm.new_value_error("operation already attempted")); } let mut flags = flags.unwrap_or(0); @@ -1410,19 +1410,19 @@ mod _overlapped { let mut inner = zelf.inner.lock(); if !matches!(inner.data, OverlappedData::None) { - return Err(vm.new_value_error("operation already attempted".to_owned())); + return Err(vm.new_value_error("operation already attempted")); } let mut flags = flags.unwrap_or(0); inner.handle = handle as HANDLE; let Some(contiguous) = buf.as_contiguous_mut() else { - return Err(vm.new_buffer_error("buffer is not contiguous".to_owned())); + return Err(vm.new_buffer_error("buffer is not contiguous")); }; let buf_len = buf.desc.len; if buf_len > u32::MAX as usize { - return Err(vm.new_value_error("buffer too large".to_owned())); + return Err(vm.new_value_error("buffer too large")); } let address: SOCKADDR_IN6 = unsafe { core::mem::zeroed() }; @@ -1856,7 +1856,7 @@ mod _overlapped { ) } } else { - return Err(vm.new_value_error("expected tuple of length 2 or 4".to_owned())); + return Err(vm.new_value_error("expected 
tuple of length 2 or 4")); }; if ret == SOCKET_ERROR { @@ -1944,7 +1944,7 @@ mod _overlapped { vm: &VirtualMachine, ) -> PyResult { if !vm.is_none(&event_attributes) { - return Err(vm.new_value_error("EventAttributes must be None".to_owned())); + return Err(vm.new_value_error("EventAttributes must be None")); } let name_wide: Option> = diff --git a/crates/stdlib/src/re.rs b/crates/stdlib/src/re.rs index fdb14d427fc..c72039f10c5 100644 --- a/crates/stdlib/src/re.rs +++ b/crates/stdlib/src/re.rs @@ -317,7 +317,7 @@ mod re { #[pyfunction] fn purge(_vm: &VirtualMachine) {} - #[pyclass] + #[pyclass(flags(HAS_WEAKREF))] impl PyPattern { #[pymethod(name = "match")] fn match_(&self, text: PyStrRef) -> Option { diff --git a/crates/stdlib/src/select.rs b/crates/stdlib/src/select.rs index 181c4573996..b52144247f7 100644 --- a/crates/stdlib/src/select.rs +++ b/crates/stdlib/src/select.rs @@ -280,7 +280,7 @@ mod decl { loop { let mut tv = timeout.map(sec_to_timeval); - let res = super::select(nfds, &mut r, &mut w, &mut x, tv.as_mut()); + let res = vm.allow_threads(|| super::select(nfds, &mut r, &mut w, &mut x, tv.as_mut())); match res { Ok(_) => break, @@ -337,7 +337,7 @@ mod decl { common::lock::PyMutex, convert::{IntoPyException, ToPyObject}, function::OptionalArg, - stdlib::io::Fildes, + stdlib::_io::Fildes, }; use core::{convert::TryFrom, time::Duration}; use libc::pollfd; @@ -502,7 +502,9 @@ mod decl { let deadline = timeout.map(|d| Instant::now() + d); let mut poll_timeout = timeout_ms; loop { - let res = unsafe { libc::poll(fds.as_mut_ptr(), fds.len() as _, poll_timeout) }; + let res = vm.allow_threads(|| unsafe { + libc::poll(fds.as_mut_ptr(), fds.len() as _, poll_timeout) + }); match nix::Error::result(res) { Ok(_) => break, Err(nix::Error::EINTR) => vm.check_signals()?, @@ -552,7 +554,7 @@ mod decl { common::lock::{PyRwLock, PyRwLockReadGuard}, convert::{IntoPyException, ToPyObject}, function::OptionalArg, - stdlib::io::Fildes, + stdlib::_io::Fildes, 
types::Constructor, }; use core::ops::Deref; @@ -695,11 +697,13 @@ mod decl { loop { events.clear(); - match epoll::wait( - epoll, - rustix::buffer::spare_capacity(&mut events), - poll_timeout.as_ref(), - ) { + match vm.allow_threads(|| { + epoll::wait( + epoll, + rustix::buffer::spare_capacity(&mut events), + poll_timeout.as_ref(), + ) + }) { Ok(_) => break, Err(rustix::io::Errno::INTR) => vm.check_signals()?, Err(e) => return Err(e.into_pyexception(vm)), diff --git a/crates/stdlib/src/socket.rs b/crates/stdlib/src/socket.rs index dce1f27d1ce..cecfbed4298 100644 --- a/crates/stdlib/src/socket.rs +++ b/crates/stdlib/src/socket.rs @@ -1105,7 +1105,8 @@ mod _socket { loop { if deadline.is_some() || matches!(select, SelectKind::Connect) { let interval = deadline.as_ref().map(|d| d.time_until()).transpose()?; - let res = sock_select(&*self.sock()?, select, interval); + let sock = self.sock()?; + let res = vm.allow_threads(|| sock_select(&sock, select, interval)); match res { Ok(true) => return Err(IoOrPyException::Timeout), Err(e) if e.kind() == io::ErrorKind::Interrupted => { @@ -1118,8 +1119,9 @@ mod _socket { } let err = loop { - // loop on interrupt - match f() { + // Detach thread state around the blocking syscall so + // stop-the-world can park this thread (e.g. before fork). 
+ match vm.allow_threads(&mut f) { Ok(x) => return Ok(x), Err(e) if e.kind() == io::ErrorKind::Interrupted => vm.check_signals()?, Err(e) => break e, @@ -1300,10 +1302,10 @@ mod _socket { // salg_type is 14 bytes, salg_name is 64 bytes if type_str.len() >= 14 { - return Err(vm.new_value_error("type too long".to_owned()).into()); + return Err(vm.new_value_error("type too long").into()); } if name_str.len() >= 64 { - return Err(vm.new_value_error("name too long".to_owned()).into()); + return Err(vm.new_value_error("name too long").into()); } // Create sockaddr_alg @@ -1342,7 +1344,8 @@ mod _socket { ) -> Result<(), IoOrPyException> { let sock_addr = self.extract_address(address, caller, vm)?; - let err = match self.sock()?.connect(&sock_addr) { + let sock = self.sock()?; + let err = match vm.allow_threads(|| sock.connect(&sock_addr)) { Ok(()) => return Ok(()), Err(e) => e, }; @@ -1381,13 +1384,20 @@ mod _socket { impl DefaultConstructor for PySocket {} + #[derive(FromArgs)] + pub struct SocketInitArgs { + #[pyarg(any, optional)] + family: OptionalArg, + #[pyarg(any, optional)] + r#type: OptionalArg, + #[pyarg(any, optional)] + proto: OptionalArg, + #[pyarg(any, optional)] + fileno: OptionalOption, + } + impl Initializer for PySocket { - type Args = ( - OptionalArg, - OptionalArg, - OptionalArg, - OptionalOption, - ); + type Args = SocketInitArgs; fn init(zelf: PyRef, args: Self::Args, vm: &VirtualMachine) -> PyResult<()> { Self::_init(zelf, args, vm).map_err(|e| e.into_pyexception(vm)) @@ -1411,13 +1421,14 @@ mod _socket { impl PySocket { fn _init( zelf: PyRef, - (family, socket_kind, proto, fileno): ::Args, + args: ::Args, vm: &VirtualMachine, ) -> Result<(), IoOrPyException> { - let mut family = family.unwrap_or(-1); - let mut socket_kind = socket_kind.unwrap_or(-1); - let mut proto = proto.unwrap_or(-1); + let mut family = args.family.unwrap_or(-1); + let mut socket_kind = args.r#type.unwrap_or(-1); + let mut proto = args.proto.unwrap_or(-1); + let fileno = 
args.fileno; let sock; // On Windows, fileno can be bytes from socket.share() for fromshare() @@ -1627,9 +1638,9 @@ mod _socket { // Handle nbytes parameter let read_len = if let OptionalArg::Present(nbytes) = nbytes { - let nbytes = nbytes.to_usize().ok_or_else(|| { - vm.new_value_error("negative buffersize in recv_into".to_owned()) - })?; + let nbytes = nbytes + .to_usize() + .ok_or_else(|| vm.new_value_error("negative buffersize in recv_into"))?; nbytes.min(buf.len()) } else { buf.len() @@ -1836,7 +1847,7 @@ mod _socket { // Validate assoclen - must be non-negative if provided let assoclen: Option = match args.assoclen { OptionalArg::Present(val) if val < 0 => { - return Err(vm.new_type_error("assoclen must be non-negative".to_owned())); + return Err(vm.new_type_error("assoclen must be non-negative")); } OptionalArg::Present(val) => Some(val as u32), OptionalArg::Missing => None, @@ -1955,15 +1966,13 @@ mod _socket { use core::mem::MaybeUninit; if bufsize < 0 { - return Err(vm.new_value_error("negative buffer size in recvmsg".to_owned())); + return Err(vm.new_value_error("negative buffer size in recvmsg")); } let bufsize = bufsize as usize; let ancbufsize = ancbufsize.unwrap_or(0); if ancbufsize < 0 { - return Err( - vm.new_value_error("negative ancillary buffer size in recvmsg".to_owned()) - ); + return Err(vm.new_value_error("negative ancillary buffer size in recvmsg")); } let ancbufsize = ancbufsize as usize; let flags = flags.unwrap_or(0); @@ -2214,12 +2223,10 @@ mod _socket { Some(t) => { let f = t.into_float(); if f.is_nan() { - return Err( - vm.new_value_error("Invalid value NaN (not a number)".to_owned()) - ); + return Err(vm.new_value_error("Invalid value NaN (not a number)")); } if f < 0.0 || !f.is_finite() { - return Err(vm.new_value_error("Timeout value out of range".to_owned())); + return Err(vm.new_value_error("Timeout value out of range")); } Some(f) } @@ -2846,14 +2853,13 @@ mod _socket { .codec_registry .encode_text(s.to_owned(), "idna", None, 
vm)?; let host_str = core::str::from_utf8(encoded.as_bytes()) - .map_err(|_| vm.new_runtime_error("idna output is not utf8".to_owned()))?; + .map_err(|_| vm.new_runtime_error("idna output is not utf8"))?; Some(host_str.to_owned()) } Some(ArgStrOrBytesLike::Buf(b)) => { let bytes = b.borrow_buf(); - let host_str = core::str::from_utf8(&bytes).map_err(|_| { - vm.new_unicode_decode_error("host bytes is not utf8".to_owned()) - })?; + let host_str = core::str::from_utf8(&bytes) + .map_err(|_| vm.new_unicode_decode_error("host bytes is not utf8"))?; Some(host_str.to_owned()) } None => None, @@ -2874,9 +2880,7 @@ mod _socket { // For bytes, check if it's valid UTF-8 let bytes = b.borrow_buf(); core::str::from_utf8(&bytes) - .map_err(|_| { - vm.new_unicode_decode_error("port is not utf8".to_owned()) - })? + .map_err(|_| vm.new_unicode_decode_error("port is not utf8"))? .to_owned() } }; @@ -3396,10 +3400,10 @@ mod _socket { Some(t) => { let f = t.into_float(); if f.is_nan() { - return Err(vm.new_value_error("Invalid value NaN (not a number)".to_owned())); + return Err(vm.new_value_error("Invalid value NaN (not a number)")); } if f < 0.0 || !f.is_finite() { - return Err(vm.new_value_error("Timeout value out of range".to_owned())); + return Err(vm.new_value_error("Timeout value out of range")); } f } diff --git a/crates/stdlib/src/ssl.rs b/crates/stdlib/src/ssl.rs index adf9e9526f1..399c703faa9 100644 --- a/crates/stdlib/src/ssl.rs +++ b/crates/stdlib/src/ssl.rs @@ -48,7 +48,7 @@ mod _ssl { function::{ ArgBytesLike, ArgMemoryBuffer, Either, FuncArgs, OptionalArg, PyComparisonValue, }, - stdlib::warnings, + stdlib::_warnings, types::{Comparable, Constructor, Hashable, PyComparisonOp, Representable}, }, }; @@ -959,7 +959,7 @@ mod _ssl { fn set_options(&self, value: i32, vm: &VirtualMachine) -> PyResult<()> { // Validate that the value is non-negative if value < 0 { - return Err(vm.new_value_error("options must be non-negative".to_owned())); + return 
Err(vm.new_value_error("options must be non-negative")); } // Deprecated SSL/TLS protocol version options @@ -976,7 +976,7 @@ mod _ssl { // Warn if any deprecated options are being newly set if (set & opt_no) != 0 { - warnings::warn( + _warnings::warn( vm.ctx.exceptions.deprecation_warning, "ssl.OP_NO_SSL*/ssl.OP_NO_TLS* options are deprecated".to_owned(), 2, // stack_level = 2 @@ -1094,14 +1094,12 @@ mod _ssl { pwd_str.as_str().to_owned() } else if let Ok(pwd_bytes_like) = ArgBytesLike::try_from_object(vm, pwd_result) { String::from_utf8(pwd_bytes_like.borrow_buf().to_vec()).map_err(|_| { - vm.new_type_error( - "password callback returned invalid UTF-8 bytes".to_owned(), - ) + vm.new_type_error("password callback returned invalid UTF-8 bytes") })? } else { - return Err(vm.new_type_error( - "password callback must return a string or bytes".to_owned(), - )); + return Err( + vm.new_type_error("password callback must return a string or bytes") + ); }; // Validate callable password length @@ -1244,9 +1242,7 @@ mod _ssl { let has_cadata = matches!(&args.cadata, OptionalArg::Present(Some(_))); if !has_cafile && !has_capath && !has_cadata { - return Err( - vm.new_type_error("cafile, capath and cadata cannot be all omitted".to_owned()) - ); + return Err(vm.new_type_error("cafile, capath and cadata cannot be all omitted")); } // Parse arguments BEFORE acquiring locks to reduce lock scope @@ -1745,7 +1741,7 @@ mod _ssl { fn load_dh_params(&self, filepath: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { // Validate filepath is not None if vm.is_none(&filepath) { - return Err(vm.new_type_error("DH params filepath cannot be None".to_owned())); + return Err(vm.new_type_error("DH params filepath cannot be None")); } // Validate filepath is str or bytes @@ -1753,9 +1749,9 @@ mod _ssl { s.as_str().to_owned() } else if let Ok(b) = ArgBytesLike::try_from_object(vm, filepath) { String::from_utf8(b.borrow_buf().to_vec()) - .map_err(|_| vm.new_value_error("Invalid path 
encoding".to_owned()))? + .map_err(|_| vm.new_value_error("Invalid path encoding"))? } else { - return Err(vm.new_type_error("DH params filepath must be str or bytes".to_owned())); + return Err(vm.new_type_error("DH params filepath must be str or bytes")); }; // Check if file exists @@ -1800,7 +1796,7 @@ mod _ssl { fn set_ecdh_curve(&self, name: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { // Validate name is not None if vm.is_none(&name) { - return Err(vm.new_type_error("ECDH curve name cannot be None".to_owned())); + return Err(vm.new_type_error("ECDH curve name cannot be None")); } // Validate name is str or bytes @@ -1808,9 +1804,9 @@ mod _ssl { s.as_str().to_owned() } else if let Ok(b) = ArgBytesLike::try_from_object(vm, name) { String::from_utf8(b.borrow_buf().to_vec()) - .map_err(|_| vm.new_value_error("Invalid curve name encoding".to_owned()))? + .map_err(|_| vm.new_value_error("Invalid curve name encoding"))? } else { - return Err(vm.new_type_error("ECDH curve name must be str or bytes".to_owned())); + return Err(vm.new_type_error("ECDH curve name must be str or bytes")); }; // Validate curve name (common curves for compatibility) @@ -2005,7 +2001,7 @@ mod _ssl { match arg { Either::A(s) => Ok(s.clone().try_into_utf8(vm)?.as_str().to_owned()), Either::B(b) => String::from_utf8(b.borrow_buf().to_vec()) - .map_err(|_| vm.new_value_error("path contains invalid UTF-8".to_owned())), + .map_err(|_| vm.new_value_error("path contains invalid UTF-8")), } } diff --git a/crates/vm/src/builtins/asyncgenerator.rs b/crates/vm/src/builtins/asyncgenerator.rs index 6b7cfea9c29..dcb1c6d6f81 100644 --- a/crates/vm/src/builtins/asyncgenerator.rs +++ b/crates/vm/src/builtins/asyncgenerator.rs @@ -40,7 +40,10 @@ impl PyPayload for PyAsyncGen { } } -#[pyclass(flags(DISALLOW_INSTANTIATION), with(PyRef, Representable, Destructor))] +#[pyclass( + flags(DISALLOW_INSTANTIATION, HAS_WEAKREF), + with(PyRef, Representable, Destructor) +)] impl PyAsyncGen { pub const fn 
as_coro(&self) -> &Coro { &self.inner diff --git a/crates/vm/src/builtins/builtin_func.rs b/crates/vm/src/builtins/builtin_func.rs index bc72b1ad533..1326febd000 100644 --- a/crates/vm/src/builtins/builtin_func.rs +++ b/crates/vm/src/builtins/builtin_func.rs @@ -16,6 +16,8 @@ pub struct PyNativeFunction { pub(crate) value: &'static PyMethodDef, pub(crate) zelf: Option, pub(crate) module: Option<&'static PyStrInterned>, // None for bound method + /// Prevent HeapMethodDef from being freed while this function references it + pub(crate) _method_def_owner: Option, } impl PyPayload for PyNativeFunction { @@ -126,7 +128,7 @@ impl Representable for PyNativeFunction { #[pyclass( with(Callable, Comparable, Representable), - flags(HAS_DICT, DISALLOW_INSTANTIATION) + flags(HAS_DICT, HAS_WEAKREF, DISALLOW_INSTANTIATION) )] impl PyNativeFunction { #[pygetset] @@ -210,7 +212,7 @@ pub struct PyNativeMethod { // All Python-visible behavior (getters, slots) is registered by PyNativeFunction::extend_class. // PyNativeMethod only extends the Rust-side struct with the defining class reference. // The func field at offset 0 (#[repr(C)]) allows NativeFunctionOrMethod to read it safely. 
-#[pyclass(flags(HAS_DICT, DISALLOW_INSTANTIATION))] +#[pyclass(flags(HAS_DICT, HAS_WEAKREF, DISALLOW_INSTANTIATION))] impl PyNativeMethod {} impl fmt::Debug for PyNativeMethod { diff --git a/crates/vm/src/builtins/bytearray.rs b/crates/vm/src/builtins/bytearray.rs index 82a283ec429..dec5cacc972 100644 --- a/crates/vm/src/builtins/bytearray.rs +++ b/crates/vm/src/builtins/bytearray.rs @@ -541,7 +541,7 @@ impl PyByteArray { #[pymethod] fn resize(&self, size: isize, vm: &VirtualMachine) -> PyResult<()> { if size < 0 { - return Err(vm.new_value_error("bytearray.resize(): new size must be >= 0".to_owned())); + return Err(vm.new_value_error("bytearray.resize(): new size must be >= 0")); } self.try_resizable(vm)?.elements.resize(size as usize, 0); Ok(()) diff --git a/crates/vm/src/builtins/classmethod.rs b/crates/vm/src/builtins/classmethod.rs index 3ec1085abc4..f42bdcc23d2 100644 --- a/crates/vm/src/builtins/classmethod.rs +++ b/crates/vm/src/builtins/classmethod.rs @@ -125,7 +125,7 @@ impl PyClassMethod { #[pyclass( with(GetDescriptor, Constructor, Initializer, Representable), - flags(BASETYPE, HAS_DICT) + flags(BASETYPE, HAS_DICT, HAS_WEAKREF) )] impl PyClassMethod { #[pygetset] diff --git a/crates/vm/src/builtins/code.rs b/crates/vm/src/builtins/code.rs index 3a9ccc35637..a7ef4c08a2d 100644 --- a/crates/vm/src/builtins/code.rs +++ b/crates/vm/src/builtins/code.rs @@ -479,9 +479,9 @@ impl Constructor for PyCode { .names .iter() .map(|obj| { - let s = obj.downcast_ref::().ok_or_else(|| { - vm.new_type_error("names must be tuple of strings".to_owned()) - })?; + let s = obj + .downcast_ref::() + .ok_or_else(|| vm.new_type_error("names must be tuple of strings"))?; Ok(vm.ctx.intern_str(s.as_wtf8())) }) .collect::>>()? 
@@ -491,9 +491,9 @@ impl Constructor for PyCode { .varnames .iter() .map(|obj| { - let s = obj.downcast_ref::().ok_or_else(|| { - vm.new_type_error("varnames must be tuple of strings".to_owned()) - })?; + let s = obj + .downcast_ref::() + .ok_or_else(|| vm.new_type_error("varnames must be tuple of strings"))?; Ok(vm.ctx.intern_str(s.as_wtf8())) }) .collect::>>()? @@ -503,9 +503,9 @@ impl Constructor for PyCode { .cellvars .iter() .map(|obj| { - let s = obj.downcast_ref::().ok_or_else(|| { - vm.new_type_error("cellvars must be tuple of strings".to_owned()) - })?; + let s = obj + .downcast_ref::() + .ok_or_else(|| vm.new_type_error("cellvars must be tuple of strings"))?; Ok(vm.ctx.intern_str(s.as_wtf8())) }) .collect::>>()? @@ -515,9 +515,9 @@ impl Constructor for PyCode { .freevars .iter() .map(|obj| { - let s = obj.downcast_ref::().ok_or_else(|| { - vm.new_type_error("freevars must be tuple of strings".to_owned()) - })?; + let s = obj + .downcast_ref::() + .ok_or_else(|| vm.new_type_error("freevars must be tuple of strings"))?; Ok(vm.ctx.intern_str(s.as_wtf8())) }) .collect::>>()? @@ -538,16 +538,14 @@ impl Constructor for PyCode { .map_err(|e| vm.new_value_error(format!("invalid bytecode: {}", e)))?; // Convert constants - let constants: Box<[Literal]> = args + let constants = args .consts .iter() .map(|obj| { - // Convert PyObject to Literal constant - // For now, just wrap it + // Convert PyObject to Literal constant. 
For now, just wrap it Literal(obj.clone()) }) - .collect::>() - .into_boxed_slice(); + .collect(); // Create locations (start and end pairs) let row = if args.firstlineno > 0 { @@ -597,7 +595,7 @@ impl Constructor for PyCode { } } -#[pyclass(with(Representable, Constructor))] +#[pyclass(with(Representable, Constructor), flags(HAS_WEAKREF))] impl PyCode { #[pygetset] const fn co_posonlyargcount(&self) -> usize { diff --git a/crates/vm/src/builtins/complex.rs b/crates/vm/src/builtins/complex.rs index f05e5a32faa..b3425d2aac1 100644 --- a/crates/vm/src/builtins/complex.rs +++ b/crates/vm/src/builtins/complex.rs @@ -7,7 +7,7 @@ use crate::{ convert::{IntoPyException, ToPyObject, ToPyResult}, function::{FuncArgs, OptionalArg, PyComparisonValue}, protocol::PyNumberMethods, - stdlib::warnings, + stdlib::_warnings, types::{AsNumber, Comparable, Constructor, Hashable, PyComparisonOp, Representable}, }; use core::cell::Cell; @@ -58,7 +58,7 @@ impl PyPayload for PyComplex { } #[inline] - unsafe fn freelist_pop() -> Option> { + unsafe fn freelist_pop(_payload: &Self) -> Option> { COMPLEX_FREELIST .try_with(|fl| { let mut list = fl.take(); @@ -95,7 +95,7 @@ impl PyObjectRef { let ret_class = result.class().to_owned(); if let Some(ret) = result.downcast_ref::() { - warnings::warn( + _warnings::warn( vm.ctx.exceptions.deprecation_warning, format!( "__complex__ returned non-complex (type {ret_class}). 
\ @@ -321,8 +321,15 @@ impl PyComplex { if spec.is_empty() { return Ok(zelf.as_object().str(vm)?.as_wtf8().to_owned()); } - FormatSpec::parse(spec.as_str()) - .and_then(|format_spec| format_spec.format_complex(&zelf.value)) + let format_spec = + FormatSpec::parse(spec.as_str()).map_err(|err| err.into_pyexception(vm))?; + let result = if format_spec.has_locale_format() { + let locale = crate::format::get_locale_info(); + format_spec.format_complex_locale(&zelf.value, &locale) + } else { + format_spec.format_complex(&zelf.value) + }; + result .map(Wtf8Buf::from_string) .map_err(|err| err.into_pyexception(vm)) } @@ -411,7 +418,7 @@ impl AsNumber for PyComplex { let result = value.norm(); // Check for overflow: hypot returns inf for finite inputs that overflow if result.is_infinite() && value.re.is_finite() && value.im.is_finite() { - return Err(vm.new_overflow_error("absolute value too large".to_owned())); + return Err(vm.new_overflow_error("absolute value too large")); } result.to_pyresult(vm) }), diff --git a/crates/vm/src/builtins/coroutine.rs b/crates/vm/src/builtins/coroutine.rs index 5b29570b2f8..9746dddda87 100644 --- a/crates/vm/src/builtins/coroutine.rs +++ b/crates/vm/src/builtins/coroutine.rs @@ -32,7 +32,7 @@ impl PyPayload for PyCoroutine { } #[pyclass( - flags(DISALLOW_INSTANTIATION), + flags(DISALLOW_INSTANTIATION, HAS_WEAKREF), with(Py, IterNext, Representable, Destructor) )] impl PyCoroutine { diff --git a/crates/vm/src/builtins/descriptor.rs b/crates/vm/src/builtins/descriptor.rs index 05e819a56e9..acfe58d723b 100644 --- a/crates/vm/src/builtins/descriptor.rs +++ b/crates/vm/src/builtins/descriptor.rs @@ -37,6 +37,8 @@ pub struct PyMethodDescriptor { pub method: &'static PyMethodDef, // vectorcall: vector_call_func, pub objclass: &'static Py, // TODO: move to tp_members + /// Prevent HeapMethodDef from being freed while this descriptor references it + pub(crate) _method_def_owner: Option, } impl PyMethodDescriptor { @@ -49,6 +51,7 @@ impl 
PyMethodDescriptor { }, method, objclass: typ, + _method_def_owner: None, } } } @@ -88,13 +91,12 @@ impl GetDescriptor for PyMethodDescriptor { } else if descr.method.flags.contains(PyMethodFlags::CLASS) { obj.class().to_owned().into() } else { - unimplemented!() + obj } } None if descr.method.flags.contains(PyMethodFlags::CLASS) => cls.unwrap(), None => return Ok(zelf), }; - // Ok(descr.method.build_bound_method(&vm.ctx, bound, class).into()) Ok(descr.bind(bound, &vm.ctx).into()) } } @@ -370,7 +372,7 @@ fn set_slot_at_object( obj.set_slot(offset, Some(v)) } PySetterValue::Delete => { - return Err(vm.new_type_error("can't delete numeric/char attribute".to_owned())); + return Err(vm.new_type_error("can't delete numeric/char attribute")); } }; } @@ -590,9 +592,7 @@ impl SlotFunc { } SlotFunc::Hash(func) => { if !args.args.is_empty() || !args.kwargs.is_empty() { - return Err( - vm.new_type_error("__hash__() takes no arguments (1 given)".to_owned()) - ); + return Err(vm.new_type_error("__hash__() takes no arguments (1 given)")); } let hash = func(&obj, vm)?; Ok(vm.ctx.new_int(hash).into()) @@ -611,26 +611,20 @@ impl SlotFunc { } SlotFunc::Iter(func) => { if !args.args.is_empty() || !args.kwargs.is_empty() { - return Err( - vm.new_type_error("__iter__() takes no arguments (1 given)".to_owned()) - ); + return Err(vm.new_type_error("__iter__() takes no arguments (1 given)")); } func(obj, vm) } SlotFunc::IterNext(func) => { if !args.args.is_empty() || !args.kwargs.is_empty() { - return Err( - vm.new_type_error("__next__() takes no arguments (1 given)".to_owned()) - ); + return Err(vm.new_type_error("__next__() takes no arguments (1 given)")); } func(&obj, vm).to_pyresult(vm) } SlotFunc::Call(func) => func(&obj, args, vm), SlotFunc::Del(func) => { if !args.args.is_empty() || !args.kwargs.is_empty() { - return Err( - vm.new_type_error("__del__() takes no arguments (1 given)".to_owned()) - ); + return Err(vm.new_type_error("__del__() takes no arguments (1 given)")); } 
func(&obj, vm)?; Ok(vm.ctx.none()) diff --git a/crates/vm/src/builtins/dict.rs b/crates/vm/src/builtins/dict.rs index f2a7e6a5a29..0e64e9e66ac 100644 --- a/crates/vm/src/builtins/dict.rs +++ b/crates/vm/src/builtins/dict.rs @@ -93,7 +93,7 @@ impl PyPayload for PyDict { } #[inline] - unsafe fn freelist_pop() -> Option> { + unsafe fn freelist_pop(_payload: &Self) -> Option> { DICT_FREELIST .try_with(|fl| { let mut list = fl.take(); diff --git a/crates/vm/src/builtins/float.rs b/crates/vm/src/builtins/float.rs index e9267a9bf00..eeddd6b2eb9 100644 --- a/crates/vm/src/builtins/float.rs +++ b/crates/vm/src/builtins/float.rs @@ -65,7 +65,7 @@ impl PyPayload for PyFloat { } #[inline] - unsafe fn freelist_pop() -> Option> { + unsafe fn freelist_pop(_payload: &Self) -> Option> { FLOAT_FREELIST .try_with(|fl| { let mut list = fl.take(); @@ -259,8 +259,15 @@ impl PyFloat { if spec.is_empty() { return Ok(zelf.as_object().str(vm)?.as_wtf8().to_owned()); } - FormatSpec::parse(spec.as_str()) - .and_then(|format_spec| format_spec.format_float(zelf.value)) + let format_spec = + FormatSpec::parse(spec.as_str()).map_err(|err| err.into_pyexception(vm))?; + let result = if format_spec.has_locale_format() { + let locale = crate::format::get_locale_info(); + format_spec.format_float_locale(zelf.value, &locale) + } else { + format_spec.format_float(zelf.value) + }; + result .map(Wtf8Buf::from_string) .map_err(|err| err.into_pyexception(vm)) } diff --git a/crates/vm/src/builtins/frame.rs b/crates/vm/src/builtins/frame.rs index 4a8549239e7..4601eee4467 100644 --- a/crates/vm/src/builtins/frame.rs +++ b/crates/vm/src/builtins/frame.rs @@ -488,13 +488,13 @@ impl Frame { PySetterValue::Assign(val) => { let line_ref: PyIntRef = val .downcast() - .map_err(|_| vm.new_value_error("lineno must be an integer".to_owned()))?; + .map_err(|_| vm.new_value_error("lineno must be an integer"))?; line_ref .try_to_primitive::(vm) - .map_err(|_| vm.new_value_error("lineno must be an integer".to_owned()))? 
+ .map_err(|_| vm.new_value_error("lineno must be an integer"))? } PySetterValue::Delete => { - return Err(vm.new_type_error("can't delete f_lineno attribute".to_owned())); + return Err(vm.new_type_error("can't delete f_lineno attribute")); } }; @@ -677,12 +677,12 @@ impl Py { // FRAME_SUSPENDED). lasti == 0 means FRAME_CREATED and // can be cleared. if self.lasti() != 0 { - return Err(vm.new_runtime_error("cannot clear a suspended frame".to_owned())); + return Err(vm.new_runtime_error("cannot clear a suspended frame")); } } FrameOwner::Thread => { // Thread-owned frame: always executing, cannot clear. - return Err(vm.new_runtime_error("cannot clear an executing frame".to_owned())); + return Err(vm.new_runtime_error("cannot clear an executing frame")); } FrameOwner::FrameObject => { // Detached frame: safe to clear. diff --git a/crates/vm/src/builtins/function.rs b/crates/vm/src/builtins/function.rs index 03663d22e5d..7a6a3ef278f 100644 --- a/crates/vm/src/builtins/function.rs +++ b/crates/vm/src/builtins/function.rs @@ -13,7 +13,7 @@ use crate::{ bytecode, class::PyClassImpl, common::wtf8::{Wtf8Buf, wtf8_concat}, - frame::Frame, + frame::{Frame, FrameRef}, function::{FuncArgs, OptionalArg, PyComparisonValue, PySetterValue}, scope::Scope, types::{ @@ -454,81 +454,87 @@ impl PyFunction { /// Set function attribute based on MakeFunctionFlags pub(crate) fn set_function_attribute( &mut self, - attr: bytecode::MakeFunctionFlags, + attr: bytecode::MakeFunctionFlag, attr_value: PyObjectRef, vm: &VirtualMachine, ) -> PyResult<()> { use crate::builtins::PyDict; - if attr == bytecode::MakeFunctionFlags::DEFAULTS { - let defaults = match attr_value.downcast::() { - Ok(tuple) => tuple, - Err(obj) => { - return Err(vm.new_type_error(format!( - "__defaults__ must be a tuple, not {}", - obj.class().name() - ))); - } - }; - self.defaults_and_kwdefaults.lock().0 = Some(defaults); - } else if attr == bytecode::MakeFunctionFlags::KW_ONLY_DEFAULTS { - let kwdefaults = match 
attr_value.downcast::() { - Ok(dict) => dict, - Err(obj) => { - return Err(vm.new_type_error(format!( - "__kwdefaults__ must be a dict, not {}", - obj.class().name() - ))); - } - }; - self.defaults_and_kwdefaults.lock().1 = Some(kwdefaults); - } else if attr == bytecode::MakeFunctionFlags::ANNOTATIONS { - let annotations = match attr_value.downcast::() { - Ok(dict) => dict, - Err(obj) => { - return Err(vm.new_type_error(format!( - "__annotations__ must be a dict, not {}", - obj.class().name() - ))); - } - }; - *self.annotations.lock() = Some(annotations); - } else if attr == bytecode::MakeFunctionFlags::CLOSURE { - // For closure, we need special handling - // The closure tuple contains cell objects - let closure_tuple = attr_value - .clone() - .downcast_exact::(vm) - .map_err(|obj| { + match attr { + bytecode::MakeFunctionFlag::Defaults => { + let defaults = match attr_value.downcast::() { + Ok(tuple) => tuple, + Err(obj) => { + return Err(vm.new_type_error(format!( + "__defaults__ must be a tuple, not {}", + obj.class().name() + ))); + } + }; + self.defaults_and_kwdefaults.lock().0 = Some(defaults); + } + bytecode::MakeFunctionFlag::KwOnlyDefaults => { + let kwdefaults = match attr_value.downcast::() { + Ok(dict) => dict, + Err(obj) => { + return Err(vm.new_type_error(format!( + "__kwdefaults__ must be a dict, not {}", + obj.class().name() + ))); + } + }; + self.defaults_and_kwdefaults.lock().1 = Some(kwdefaults); + } + bytecode::MakeFunctionFlag::Annotations => { + let annotations = match attr_value.downcast::() { + Ok(dict) => dict, + Err(obj) => { + return Err(vm.new_type_error(format!( + "__annotations__ must be a dict, not {}", + obj.class().name() + ))); + } + }; + *self.annotations.lock() = Some(annotations); + } + bytecode::MakeFunctionFlag::Closure => { + let closure_tuple = attr_value + .clone() + .downcast_exact::(vm) + .map_err(|obj| { + vm.new_type_error(format!( + "closure must be a tuple, not {}", + obj.class().name() + )) + })? 
+ .into_pyref(); + + self.closure = Some(closure_tuple.try_into_typed::(vm)?); + } + bytecode::MakeFunctionFlag::TypeParams => { + let type_params = attr_value.clone().downcast::().map_err(|_| { vm.new_type_error(format!( - "closure must be a tuple, not {}", - obj.class().name() + "__type_params__ must be a tuple, not {}", + attr_value.class().name() )) - })? - .into_pyref(); - - self.closure = Some(closure_tuple.try_into_typed::(vm)?); - } else if attr == bytecode::MakeFunctionFlags::TYPE_PARAMS { - let type_params = attr_value.clone().downcast::().map_err(|_| { - vm.new_type_error(format!( - "__type_params__ must be a tuple, not {}", - attr_value.class().name() - )) - })?; - *self.type_params.lock() = type_params; - } else if attr == bytecode::MakeFunctionFlags::ANNOTATE { - // PEP 649: Store the __annotate__ function closure - if !attr_value.is_callable() { - return Err(vm.new_type_error("__annotate__ must be callable".to_owned())); + })?; + *self.type_params.lock() = type_params; + } + bytecode::MakeFunctionFlag::Annotate => { + if !attr_value.is_callable() { + return Err(vm.new_type_error("__annotate__ must be callable")); + } + *self.annotate.lock() = Some(attr_value); } - *self.annotate.lock() = Some(attr_value); - } else { - unreachable!("This is a compiler bug"); } Ok(()) } } impl Py { + pub(crate) fn is_optimized_for_call_specialization(&self) -> bool { + self.code.flags.contains(bytecode::CodeFlags::OPTIMIZED) + } + pub fn invoke_with_locals( &self, func_args: FuncArgs, @@ -636,52 +642,80 @@ impl Py { new_v } + /// function_kind(SIMPLE_FUNCTION) equivalent for CALL specialization. + /// Returns true if: CO_OPTIMIZED, no VARARGS, no VARKEYWORDS, no kwonly args. 
+ pub(crate) fn is_simple_for_call_specialization(&self) -> bool { + let code: &Py = &self.code; + let flags = code.flags; + flags.contains(bytecode::CodeFlags::OPTIMIZED) + && !flags.intersects(bytecode::CodeFlags::VARARGS | bytecode::CodeFlags::VARKEYWORDS) + && code.kwonlyarg_count == 0 + } + /// Check if this function is eligible for exact-args call specialization. - /// Returns true if: no VARARGS, no VARKEYWORDS, no kwonly args, not generator/coroutine, + /// Returns true if: CO_OPTIMIZED, no VARARGS, no VARKEYWORDS, no kwonly args, /// and effective_nargs matches co_argcount. pub(crate) fn can_specialize_call(&self, effective_nargs: u32) -> bool { let code: &Py = &self.code; let flags = code.flags; - flags.contains(bytecode::CodeFlags::NEWLOCALS) - && !flags.intersects( - bytecode::CodeFlags::VARARGS - | bytecode::CodeFlags::VARKEYWORDS - | bytecode::CodeFlags::GENERATOR - | bytecode::CodeFlags::COROUTINE, - ) + flags.contains(bytecode::CodeFlags::OPTIMIZED) + && !flags.intersects(bytecode::CodeFlags::VARARGS | bytecode::CodeFlags::VARKEYWORDS) && code.kwonlyarg_count == 0 && code.arg_count == effective_nargs } - /// Fast path for calling a simple function with exact positional args. - /// Skips FuncArgs allocation, prepend_arg, and fill_locals_from_args. - /// Only valid when: no VARARGS, no VARKEYWORDS, no kwonlyargs, not generator/coroutine, - /// and nargs == co_argcount. - pub fn invoke_exact_args(&self, mut args: Vec, vm: &VirtualMachine) -> PyResult { + /// Runtime guard for CALL_*_EXACT_ARGS specialization: check only argcount. + /// Other invariants are guaranteed by function versioning and specialization-time checks. + #[inline] + pub(crate) fn has_exact_argcount(&self, effective_nargs: u32) -> bool { + self.code.arg_count == effective_nargs + } + + /// Bytes required for this function's frame on RustPython's thread datastack. 
+ /// Returns `None` for generator/coroutine code paths that do not push a + /// regular datastack-backed frame in the fast call path. + pub(crate) fn datastack_frame_size_bytes(&self) -> Option { + datastack_frame_size_bytes_for_code(&self.code) + } + + pub(crate) fn prepare_exact_args_frame( + &self, + mut args: Vec, + vm: &VirtualMachine, + ) -> FrameRef { let code: PyRef = (*self.code).to_owned(); debug_assert_eq!(args.len(), code.arg_count as usize); - debug_assert!(code.flags.contains(bytecode::CodeFlags::NEWLOCALS)); - debug_assert!(!code.flags.intersects( - bytecode::CodeFlags::VARARGS - | bytecode::CodeFlags::VARKEYWORDS - | bytecode::CodeFlags::GENERATOR - | bytecode::CodeFlags::COROUTINE - )); + debug_assert!(code.flags.contains(bytecode::CodeFlags::OPTIMIZED)); + debug_assert!( + !code + .flags + .intersects(bytecode::CodeFlags::VARARGS | bytecode::CodeFlags::VARKEYWORDS) + ); debug_assert_eq!(code.kwonlyarg_count, 0); + debug_assert!( + !code + .flags + .intersects(bytecode::CodeFlags::GENERATOR | bytecode::CodeFlags::COROUTINE) + ); + + let locals = if code.flags.contains(bytecode::CodeFlags::NEWLOCALS) { + None + } else { + Some(ArgMapping::from_dict_exact(self.globals.clone())) + }; let frame = Frame::new( code.clone(), - Scope::new(None, self.globals.clone()), + Scope::new(locals, self.globals.clone()), self.builtins.clone(), self.closure.as_ref().map_or(&[], |c| c.as_slice()), Some(self.to_owned().into()), - true, // Always use datastack (invoke_exact_args is never gen/coro) + true, // Exact-args fast path is only used for non-gen/coro functions. 
vm, ) .into_ref(&vm.ctx); - // Move args directly into fastlocals (no clone/refcount needed) { let fastlocals = unsafe { frame.fastlocals_mut() }; for (slot, arg) in fastlocals.iter_mut().zip(args.drain(..)) { @@ -689,7 +723,6 @@ impl Py { } } - // Handle cell2arg if let Some(cell2arg) = code.cell2arg.as_deref() { let fastlocals = unsafe { frame.fastlocals_mut() }; for (cell_idx, arg_idx) in cell2arg.iter().enumerate().filter(|(_, i)| **i != -1) { @@ -698,6 +731,36 @@ impl Py { } } + frame + } + + /// Fast path for calling a simple function with exact positional args. + /// Skips FuncArgs allocation, prepend_arg, and fill_locals_from_args. + /// Only valid when: CO_OPTIMIZED, no VARARGS, no VARKEYWORDS, no kwonlyargs, + /// and nargs == co_argcount. + pub fn invoke_exact_args(&self, args: Vec, vm: &VirtualMachine) -> PyResult { + let code: PyRef = (*self.code).to_owned(); + + debug_assert_eq!(args.len(), code.arg_count as usize); + debug_assert!(code.flags.contains(bytecode::CodeFlags::OPTIMIZED)); + debug_assert!( + !code + .flags + .intersects(bytecode::CodeFlags::VARARGS | bytecode::CodeFlags::VARKEYWORDS) + ); + debug_assert_eq!(code.kwonlyarg_count, 0); + + // Generator/coroutine code objects are SIMPLE_FUNCTION in call + // specialization classification, but their call path must still + // go through invoke() to produce generator/coroutine objects. 
+ if code + .flags + .intersects(bytecode::CodeFlags::GENERATOR | bytecode::CodeFlags::COROUTINE) + { + return self.invoke(FuncArgs::from(args), vm); + } + let frame = self.prepare_exact_args_frame(args, vm); + let result = vm.run_frame(frame.clone()); unsafe { if let Some(base) = frame.materialize_localsplus() { @@ -708,6 +771,22 @@ impl Py { } } +pub(crate) fn datastack_frame_size_bytes_for_code(code: &Py) -> Option { + if code + .flags + .intersects(bytecode::CodeFlags::GENERATOR | bytecode::CodeFlags::COROUTINE) + { + return None; + } + let nlocalsplus = code + .varnames + .len() + .checked_add(code.cellvars.len())? + .checked_add(code.freevars.len())?; + let capacity = nlocalsplus.checked_add(code.max_stackdepth as usize)?; + capacity.checked_mul(core::mem::size_of::()) +} + impl PyPayload for PyFunction { #[inline] fn class(ctx: &Context) -> &'static Py { @@ -717,7 +796,7 @@ impl PyPayload for PyFunction { #[pyclass( with(GetDescriptor, Callable, Representable, Constructor), - flags(HAS_DICT, METHOD_DESCRIPTOR) + flags(HAS_DICT, HAS_WEAKREF, METHOD_DESCRIPTOR) )] impl PyFunction { #[pygetset] @@ -1170,7 +1249,7 @@ impl PyBoundMethod { #[pyclass( with(Callable, Comparable, GetAttr, Constructor, Representable), - flags(IMMUTABLETYPE) + flags(IMMUTABLETYPE, HAS_WEAKREF) )] impl PyBoundMethod { #[pymethod] @@ -1300,6 +1379,7 @@ pub(crate) fn vectorcall_function( let has_kwargs = kwnames.is_some_and(|kw| !kw.is_empty()); let is_simple = !has_kwargs + && code.flags.contains(bytecode::CodeFlags::OPTIMIZED) && !code.flags.contains(bytecode::CodeFlags::VARARGS) && !code.flags.contains(bytecode::CodeFlags::VARKEYWORDS) && code.kwonlyarg_count == 0 @@ -1310,37 +1390,8 @@ pub(crate) fn vectorcall_function( if is_simple && nargs == code.arg_count as usize { // FAST PATH: simple positional-only call, exact arg count. // Move owned args directly into fastlocals — no clone needed. 
- let locals = if code.flags.contains(bytecode::CodeFlags::NEWLOCALS) { - None // lazy allocation — most frames never access locals dict - } else { - Some(ArgMapping::from_dict_exact(zelf.globals.clone())) - }; - - let frame = Frame::new( - code.to_owned(), - Scope::new(locals, zelf.globals.clone()), - zelf.builtins.clone(), - zelf.closure.as_ref().map_or(&[], |c| c.as_slice()), - Some(zelf.to_owned().into()), - true, // Always use datastack (is_simple excludes gen/coro) - vm, - ) - .into_ref(&vm.ctx); - - { - let fastlocals = unsafe { frame.fastlocals_mut() }; - for (slot, arg) in fastlocals.iter_mut().zip(args.drain(..nargs)) { - *slot = Some(arg); - } - } - - if let Some(cell2arg) = code.cell2arg.as_deref() { - let fastlocals = unsafe { frame.fastlocals_mut() }; - for (cell_idx, arg_idx) in cell2arg.iter().enumerate().filter(|(_, i)| **i != -1) { - let x = fastlocals[*arg_idx as usize].take(); - frame.set_cell_contents(cell_idx, x); - } - } + args.truncate(nargs); + let frame = zelf.prepare_exact_args_frame(args, vm); let result = vm.run_frame(frame.clone()); unsafe { diff --git a/crates/vm/src/builtins/generator.rs b/crates/vm/src/builtins/generator.rs index fd822e9fbfe..2eee2fecd0d 100644 --- a/crates/vm/src/builtins/generator.rs +++ b/crates/vm/src/builtins/generator.rs @@ -34,7 +34,7 @@ impl PyPayload for PyGenerator { } #[pyclass( - flags(DISALLOW_INSTANTIATION), + flags(DISALLOW_INSTANTIATION, HAS_WEAKREF), with(Py, IterNext, Iterable, Representable, Destructor) )] impl PyGenerator { diff --git a/crates/vm/src/builtins/genericalias.rs b/crates/vm/src/builtins/genericalias.rs index 96da93dd8ef..1564229d186 100644 --- a/crates/vm/src/builtins/genericalias.rs +++ b/crates/vm/src/builtins/genericalias.rs @@ -84,7 +84,7 @@ impl Constructor for PyGenericAlias { Iterable, Representable ), - flags(BASETYPE) + flags(BASETYPE, HAS_WEAKREF) )] impl PyGenericAlias { pub fn new( @@ -155,10 +155,11 @@ impl PyGenericAlias { let mut parts = Vec::with_capacity(len); // Use 
indexed access so list mutation during repr causes IndexError for i in 0..len { - let item = - list.borrow_vec().get(i).cloned().ok_or_else(|| { - vm.new_index_error("list index out of range".to_owned()) - })?; + let item = list + .borrow_vec() + .get(i) + .cloned() + .ok_or_else(|| vm.new_index_error("list index out of range"))?; parts.push(repr_item(item, vm)?); } Ok(format!("[{}]", parts.join(", "))) @@ -712,9 +713,9 @@ impl crate::types::IterNext for PyGenericAliasIterator { None => return Ok(PyIterReturn::StopIteration(None)), }; // Create a starred GenericAlias from the original - let alias = obj.downcast_ref::().ok_or_else(|| { - vm.new_type_error("generic_alias_iterator expected GenericAlias".to_owned()) - })?; + let alias = obj + .downcast_ref::() + .ok_or_else(|| vm.new_type_error("generic_alias_iterator expected GenericAlias"))?; let starred = PyGenericAlias::new(alias.origin.clone(), alias.args.clone(), true, vm); Ok(PyIterReturn::Return(starred.into_pyobject(vm))) } @@ -734,7 +735,7 @@ pub fn subscript_generic(type_params: PyObjectRef, vm: &VirtualMachine) -> PyRes PyTuple::new_ref(vec![type_params], &vm.ctx) }; - let args = crate::stdlib::typing::unpack_typevartuples(¶ms, vm)?; + let args = crate::stdlib::_typing::unpack_typevartuples(¶ms, vm)?; generic_alias_class.call((generic_type, args.to_pyobject(vm)), vm) } diff --git a/crates/vm/src/builtins/int.rs b/crates/vm/src/builtins/int.rs index 01863615ac1..a253506eba1 100644 --- a/crates/vm/src/builtins/int.rs +++ b/crates/vm/src/builtins/int.rs @@ -86,7 +86,7 @@ impl PyPayload for PyInt { } #[inline] - unsafe fn freelist_pop() -> Option> { + unsafe fn freelist_pop(_payload: &Self) -> Option> { INT_FREELIST .try_with(|fl| { let mut list = fl.take(); @@ -499,8 +499,15 @@ impl PyInt { if spec.is_empty() && !zelf.class().is(vm.ctx.types.int_type) { return Ok(zelf.as_object().str(vm)?.as_wtf8().to_owned()); } - FormatSpec::parse(spec.as_str()) - .and_then(|format_spec| format_spec.format_int(&zelf.value)) 
+ let format_spec = + FormatSpec::parse(spec.as_str()).map_err(|err| err.into_pyexception(vm))?; + let result = if format_spec.has_locale_format() { + let locale = crate::format::get_locale_info(); + format_spec.format_int_locale(&zelf.value, &locale) + } else { + format_spec.format_int(&zelf.value) + }; + result .map(Wtf8Buf::from_string) .map_err(|err| err.into_pyexception(vm)) } diff --git a/crates/vm/src/builtins/list.rs b/crates/vm/src/builtins/list.rs index c13dea57169..cdb8a73ead2 100644 --- a/crates/vm/src/builtins/list.rs +++ b/crates/vm/src/builtins/list.rs @@ -105,7 +105,7 @@ impl PyPayload for PyList { } #[inline] - unsafe fn freelist_pop() -> Option> { + unsafe fn freelist_pop(_payload: &Self) -> Option> { LIST_FREELIST .try_with(|fl| { let mut list = fl.take(); @@ -286,7 +286,16 @@ impl PyList { fn _setitem(&self, needle: &PyObject, value: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { match SequenceIndex::try_from_borrowed_object(vm, needle, "list")? { - SequenceIndex::Int(index) => self.borrow_vec_mut().setitem_by_index(vm, index, value), + SequenceIndex::Int(index) => self + .borrow_vec_mut() + .setitem_by_index(vm, index, value) + .map_err(|e| { + if e.class().is(vm.ctx.exceptions.index_error) { + vm.new_index_error("list assignment index out of range".to_owned()) + } else { + e + } + }), SequenceIndex::Slice(slice) => { let sec = extract_cloned(&value, Ok, vm)?; self.borrow_vec_mut().setitem_by_slice(vm, slice, &sec) @@ -509,6 +518,13 @@ impl AsSequence for PyList { } else { zelf.borrow_vec_mut().delitem_by_index(vm, i) } + .map_err(|e| { + if e.class().is(vm.ctx.exceptions.index_error) { + vm.new_index_error("list assignment index out of range".to_owned()) + } else { + e + } + }) }), contains: atomic_func!(|seq, target, vm| { let zelf = PyList::sequence_downcast(seq); diff --git a/crates/vm/src/builtins/memory.rs b/crates/vm/src/builtins/memory.rs index a3403287dae..73eb1f1780b 100644 --- a/crates/vm/src/builtins/memory.rs +++ 
b/crates/vm/src/builtins/memory.rs @@ -549,7 +549,7 @@ impl Py { Iterable, Representable ), - flags(SEQUENCE) + flags(SEQUENCE, HAS_WEAKREF) )] impl PyMemoryView { #[pyclassmethod] @@ -700,7 +700,7 @@ impl PyMemoryView { self.try_not_released(vm)?; if self.desc.ndim() == 0 { // 0-dimensional memoryview has no length - Err(vm.new_type_error("0-dim memory has no length".to_owned())) + Err(vm.new_type_error("0-dim memory has no length")) } else { // shape for dim[0] Ok(self.desc.dim_desc[0].0) diff --git a/crates/vm/src/builtins/module.rs b/crates/vm/src/builtins/module.rs index 0dc2b571eae..cabaf1d63cb 100644 --- a/crates/vm/src/builtins/module.rs +++ b/crates/vm/src/builtins/module.rs @@ -286,7 +286,10 @@ impl Py { } } -#[pyclass(with(GetAttr, Initializer, Representable), flags(BASETYPE, HAS_DICT))] +#[pyclass( + with(GetAttr, Initializer, Representable), + flags(BASETYPE, HAS_DICT, HAS_WEAKREF) +)] impl PyModule { #[pyslot] fn slot_new(cls: PyTypeRef, _args: FuncArgs, vm: &VirtualMachine) -> PyResult { @@ -405,7 +408,7 @@ impl PyModule { } PySetterValue::Delete => { if dict.del_item(identifier!(vm, __annotations__), vm).is_err() { - return Err(vm.new_attribute_error("__annotations__".to_owned())); + return Err(vm.new_attribute_error("__annotations__")); } // Also clear __annotate__ dict.del_item(identifier!(vm, __annotate__), vm).ok(); diff --git a/crates/vm/src/builtins/namespace.rs b/crates/vm/src/builtins/namespace.rs index a32dda14586..4e872a172a4 100644 --- a/crates/vm/src/builtins/namespace.rs +++ b/crates/vm/src/builtins/namespace.rs @@ -28,7 +28,7 @@ impl PyPayload for PyNamespace { impl DefaultConstructor for PyNamespace {} #[pyclass( - flags(BASETYPE, HAS_DICT), + flags(BASETYPE, HAS_DICT, HAS_WEAKREF), with(Constructor, Initializer, Comparable, Representable) )] impl PyNamespace { diff --git a/crates/vm/src/builtins/object.rs b/crates/vm/src/builtins/object.rs index 8fed43cd5d7..002b05d38f1 100644 --- a/crates/vm/src/builtins/object.rs +++ 
b/crates/vm/src/builtins/object.rs @@ -43,8 +43,7 @@ impl Constructor for PyBaseObject { // Type has its own __new__, so object.__new__ is being called // with excess args. This is the first error case in CPython return Err(vm.new_type_error( - "object.__new__() takes exactly one argument (the type to instantiate)" - .to_owned(), + "object.__new__() takes exactly one argument (the type to instantiate)", )); } @@ -65,19 +64,6 @@ impl Constructor for PyBaseObject { } } - // more or less __new__ operator - // Only create dict if the class has HAS_DICT flag (i.e., __slots__ was not defined - // or __dict__ is in __slots__) - let dict = if cls - .slots - .flags - .has_feature(crate::types::PyTypeFlags::HAS_DICT) - { - Some(vm.ctx.new_dict()) - } else { - None - }; - // Ensure that all abstract methods are implemented before instantiating instance. if let Some(abs_methods) = cls.get_attr(identifier!(vm, __abstractmethods__)) && let Some(unimplemented_abstract_method_count) = abs_methods.length_opt(vm) @@ -110,7 +96,7 @@ impl Constructor for PyBaseObject { } } - Ok(crate::PyRef::new_ref(Self, cls, dict).into()) + generic_alloc(cls, 0, vm) } fn py_new(_cls: &Py, _args: Self::Args, _vm: &VirtualMachine) -> PyResult { @@ -118,6 +104,21 @@ impl Constructor for PyBaseObject { } } +pub(crate) fn generic_alloc(cls: PyTypeRef, _nitems: usize, vm: &VirtualMachine) -> PyResult { + // Only create dict if the class has HAS_DICT flag (i.e., __slots__ was not defined + // or __dict__ is in __slots__) + let dict = if cls + .slots + .flags + .has_feature(crate::types::PyTypeFlags::HAS_DICT) + { + Some(vm.ctx.new_dict()) + } else { + None + }; + Ok(crate::PyRef::new_ref(PyBaseObject, cls, dict).into()) +} + impl Initializer for PyBaseObject { type Args = FuncArgs; @@ -136,8 +137,7 @@ impl Initializer for PyBaseObject { // if (type->tp_init != object_init) → first error if typ_init != object_init { return Err(vm.new_type_error( - "object.__init__() takes exactly one argument (the instance 
to initialize)" - .to_owned(), + "object.__init__() takes exactly one argument (the instance to initialize)", )); } @@ -250,9 +250,7 @@ fn object_getstate_default(obj: &PyObject, required: bool, vm: &VirtualMachine) let borrowed_names = slot_names.borrow_vec(); // Check if slotnames changed during iteration if borrowed_names.len() != slot_names_len { - return Err(vm.new_runtime_error( - "__slotnames__ changed size during iteration".to_owned(), - )); + return Err(vm.new_runtime_error("__slotnames__ changed size during iteration")); } let name = borrowed_names[i].downcast_ref::().unwrap(); let Ok(value) = obj.get_attr(name, vm) else { @@ -464,6 +462,8 @@ impl PyBaseObject { if both_mutable || both_module { let has_dict = |typ: &Py| typ.slots.flags.has_feature(PyTypeFlags::HAS_DICT); + let has_weakref = + |typ: &Py| typ.slots.flags.has_feature(PyTypeFlags::HAS_WEAKREF); // Compare slots tuples let slots_equal = match ( current_cls @@ -484,6 +484,8 @@ impl PyBaseObject { if current_cls.slots.basicsize != cls.slots.basicsize || !slots_equal || has_dict(current_cls) != has_dict(&cls) + || has_weakref(current_cls) != has_weakref(&cls) + || current_cls.slots.member_count != cls.slots.member_count { return Err(vm.new_type_error(format!( "__class__ assignment: '{}' object layout differs from '{}'", @@ -561,8 +563,9 @@ pub fn object_set_dict(obj: PyObjectRef, dict: PyDictRef, vm: &VirtualMachine) - } pub fn init(ctx: &'static Context) { - // Manually set init slot - derive macro doesn't generate extend_slots + // Manually set alloc/init slots - derive macro doesn't generate extend_slots // for trait impl that overrides #[pyslot] method + ctx.types.object_type.slots.alloc.store(Some(generic_alloc)); ctx.types .object_type .slots @@ -714,7 +717,7 @@ fn reduce_newobj(obj: PyObjectRef, vm: &VirtualMachine) -> PyResult { } else { // args == NULL with non-empty kwargs is BadInternalCall let Some(args) = args else { - return Err(vm.new_system_error("bad internal call".to_owned())); 
+ return Err(vm.new_system_error("bad internal call")); }; // Use copyreg.__newobj_ex__ let newobj = copyreg.get_attr("__newobj_ex__", vm)?; diff --git a/crates/vm/src/builtins/property.rs b/crates/vm/src/builtins/property.rs index 509307c7b00..d01477dfcbf 100644 --- a/crates/vm/src/builtins/property.rs +++ b/crates/vm/src/builtins/property.rs @@ -65,7 +65,10 @@ impl GetDescriptor for PyProperty { } } -#[pyclass(with(Constructor, Initializer, GetDescriptor), flags(BASETYPE))] +#[pyclass( + with(Constructor, Initializer, GetDescriptor), + flags(BASETYPE, HAS_WEAKREF) +)] impl PyProperty { // Helper method to get property name // Returns the name if available, None if not found, or propagates errors @@ -151,9 +154,7 @@ impl PyProperty { fn name_getter(&self, vm: &VirtualMachine) -> PyResult { match self.get_property_name(vm)? { Some(name) => Ok(name), - None => Err( - vm.new_attribute_error("'property' object has no attribute '__name__'".to_owned()) - ), + None => Err(vm.new_attribute_error("'property' object has no attribute '__name__'")), } } diff --git a/crates/vm/src/builtins/range.rs b/crates/vm/src/builtins/range.rs index 795ec230ba9..153a82bb43b 100644 --- a/crates/vm/src/builtins/range.rs +++ b/crates/vm/src/builtins/range.rs @@ -101,7 +101,7 @@ impl PyPayload for PyRange { } #[inline] - unsafe fn freelist_pop() -> Option> { + unsafe fn freelist_pop(_payload: &Self) -> Option> { RANGE_FREELIST .try_with(|fl| { let mut list = fl.take(); diff --git a/crates/vm/src/builtins/set.rs b/crates/vm/src/builtins/set.rs index 2b1e9c82e60..85e6b37fab0 100644 --- a/crates/vm/src/builtins/set.rs +++ b/crates/vm/src/builtins/set.rs @@ -531,7 +531,7 @@ fn reduce_set( AsNumber, Representable ), - flags(BASETYPE, _MATCH_SELF) + flags(BASETYPE, _MATCH_SELF, HAS_WEAKREF) )] impl PySet { fn __len__(&self) -> usize { @@ -996,7 +996,7 @@ impl Constructor for PyFrozenSet { } #[pyclass( - flags(BASETYPE, _MATCH_SELF), + flags(BASETYPE, _MATCH_SELF, HAS_WEAKREF), with( Constructor, 
AsSequence, diff --git a/crates/vm/src/builtins/singletons.rs b/crates/vm/src/builtins/singletons.rs index 7102e8ebfa3..9794a58d41b 100644 --- a/crates/vm/src/builtins/singletons.rs +++ b/crates/vm/src/builtins/singletons.rs @@ -110,9 +110,7 @@ impl AsNumber for PyNotImplemented { fn as_number() -> &'static PyNumberMethods { static AS_NUMBER: PyNumberMethods = PyNumberMethods { boolean: Some(|_number, vm| { - Err(vm.new_type_error( - "NotImplemented should not be used in a boolean context".to_owned(), - )) + Err(vm.new_type_error("NotImplemented should not be used in a boolean context")) }), ..PyNumberMethods::NOT_IMPLEMENTED }; diff --git a/crates/vm/src/builtins/slice.rs b/crates/vm/src/builtins/slice.rs index aeb3337c7d8..b46f7a3a56a 100644 --- a/crates/vm/src/builtins/slice.rs +++ b/crates/vm/src/builtins/slice.rs @@ -76,7 +76,7 @@ impl PyPayload for PySlice { } #[inline] - unsafe fn freelist_pop() -> Option> { + unsafe fn freelist_pop(_payload: &Self) -> Option> { SLICE_FREELIST .try_with(|fl| { let mut list = fl.take(); diff --git a/crates/vm/src/builtins/staticmethod.rs b/crates/vm/src/builtins/staticmethod.rs index a06267650a2..2554fa816aa 100644 --- a/crates/vm/src/builtins/staticmethod.rs +++ b/crates/vm/src/builtins/staticmethod.rs @@ -88,7 +88,7 @@ impl Initializer for PyStaticMethod { #[pyclass( with(Callable, GetDescriptor, Constructor, Initializer, Representable), - flags(BASETYPE, HAS_DICT) + flags(BASETYPE, HAS_DICT, HAS_WEAKREF) )] impl PyStaticMethod { #[pygetset] diff --git a/crates/vm/src/builtins/str.rs b/crates/vm/src/builtins/str.rs index 38102c18865..8e98fc6e5c4 100644 --- a/crates/vm/src/builtins/str.rs +++ b/crates/vm/src/builtins/str.rs @@ -1500,14 +1500,25 @@ impl PyRef { } pub fn concat_in_place(&mut self, other: &Wtf8, vm: &VirtualMachine) { - // TODO: call [A]Rc::get_mut on the str to try to mutate the data in place if other.is_empty() { return; } let mut s = Wtf8Buf::with_capacity(self.byte_len() + other.len()); 
s.push_wtf8(self.as_ref()); s.push_wtf8(other); - *self = PyStr::from(s).into_ref(&vm.ctx); + if self.as_object().strong_count() == 1 { + // SAFETY: strong_count()==1 guarantees unique ownership of this PyStr. + // Mutating payload in place preserves semantics while avoiding PyObject reallocation. + unsafe { + let payload = self.payload() as *const PyStr as *mut PyStr; + (*payload).data = PyStr::from(s).data; + (*payload) + .hash + .store(hash::SENTINEL, atomic::Ordering::Relaxed); + } + } else { + *self = PyStr::from(s).into_ref(&vm.ctx); + } } pub fn try_into_utf8(self, vm: &VirtualMachine) -> PyResult> { @@ -1678,13 +1689,23 @@ impl ToPyObject for Wtf8Buf { impl ToPyObject for char { fn to_pyobject(self, vm: &VirtualMachine) -> PyObjectRef { - vm.ctx.new_str(self).into() + let cp = self as u32; + if cp <= u8::MAX as u32 { + vm.ctx.latin1_char(cp as u8).into() + } else { + vm.ctx.new_str(self).into() + } } } impl ToPyObject for CodePoint { fn to_pyobject(self, vm: &VirtualMachine) -> PyObjectRef { - vm.ctx.new_str(self).into() + let cp = self.to_u32(); + if cp <= u8::MAX as u32 { + vm.ctx.latin1_char(cp as u8).into() + } else { + vm.ctx.new_str(self).into() + } } } @@ -1726,7 +1747,7 @@ impl ToPyObject for AsciiString { impl ToPyObject for AsciiChar { fn to_pyobject(self, vm: &VirtualMachine) -> PyObjectRef { - vm.ctx.new_str(self).into() + vm.ctx.latin1_char(u8::from(self)).into() } } diff --git a/crates/vm/src/builtins/traceback.rs b/crates/vm/src/builtins/traceback.rs index 975e81bb7f4..c6eac4e87e7 100644 --- a/crates/vm/src/builtins/traceback.rs +++ b/crates/vm/src/builtins/traceback.rs @@ -81,14 +81,14 @@ impl PyTraceback { let value = match value { PySetterValue::Assign(v) => v, PySetterValue::Delete => { - return Err(vm.new_type_error("can't delete tb_next attribute".to_owned())); + return Err(vm.new_type_error("can't delete tb_next attribute")); } }; if let Some(ref new_next) = value { let mut cursor = new_next.clone(); loop { if cursor.is(zelf) { - 
return Err(vm.new_value_error("traceback loop detected".to_owned())); + return Err(vm.new_value_error("traceback loop detected")); } let next = cursor.next.lock().clone(); match next { @@ -107,8 +107,8 @@ impl Constructor for PyTraceback { fn py_new(_cls: &Py, args: Self::Args, vm: &VirtualMachine) -> PyResult { let (next, frame, lasti, lineno) = args; - let lineno = OneIndexed::new(lineno) - .ok_or_else(|| vm.new_value_error("lineno must be positive".to_owned()))?; + let lineno = + OneIndexed::new(lineno).ok_or_else(|| vm.new_value_error("lineno must be positive"))?; Ok(Self::new(next, frame, lasti, lineno)) } } diff --git a/crates/vm/src/builtins/tuple.rs b/crates/vm/src/builtins/tuple.rs index 03f88f1b5fe..623f7144796 100644 --- a/crates/vm/src/builtins/tuple.rs +++ b/crates/vm/src/builtins/tuple.rs @@ -27,6 +27,8 @@ use crate::{ vm::VirtualMachine, }; use alloc::fmt; +use core::cell::Cell; +use core::ptr::NonNull; #[pyclass(module = false, name = "tuple", traverse = "manual")] pub struct PyTuple { @@ -53,14 +55,97 @@ unsafe impl Traverse for PyTuple { } } -// No freelist for PyTuple: structseq types (stat_result, struct_time, etc.) -// are static subtypes sharing the same Rust payload, making type-safe reuse -// impractical without a type-pointer comparison at push time. +// spell-checker:ignore MAXSAVESIZE +/// Per-size freelist storage for tuples, matching tuples[PyTuple_MAXSAVESIZE]. +/// Each bucket caches tuples of a specific element count (index = len - 1). +struct TupleFreeList { + buckets: [Vec>; Self::MAX_SAVE_SIZE], +} + +impl TupleFreeList { + /// Largest tuple size to cache on the freelist (sizes 1..=20). 
+ const MAX_SAVE_SIZE: usize = 20; + const fn new() -> Self { + Self { + buckets: [const { Vec::new() }; Self::MAX_SAVE_SIZE], + } + } +} + +impl Default for TupleFreeList { + fn default() -> Self { + Self::new() + } +} + +impl Drop for TupleFreeList { + fn drop(&mut self) { + // Same safety pattern as FreeList::drop — free raw allocation + // without running payload destructors to avoid TLS-after-destruction panics. + let layout = crate::object::pyinner_layout::(); + for bucket in &mut self.buckets { + for ptr in bucket.drain(..) { + unsafe { + alloc::alloc::dealloc(ptr.as_ptr() as *mut u8, layout); + } + } + } + } +} + +thread_local! { + static TUPLE_FREELIST: Cell = const { Cell::new(TupleFreeList::new()) }; +} + impl PyPayload for PyTuple { + const MAX_FREELIST: usize = 2000; + const HAS_FREELIST: bool = true; + #[inline] fn class(ctx: &Context) -> &'static Py { ctx.types.tuple_type } + + #[inline] + unsafe fn freelist_push(obj: *mut PyObject) -> bool { + let len = unsafe { &*(obj as *const crate::Py) } + .elements + .len(); + if len == 0 || len > TupleFreeList::MAX_SAVE_SIZE { + return false; + } + TUPLE_FREELIST + .try_with(|fl| { + let mut list = fl.take(); + let bucket = &mut list.buckets[len - 1]; + let stored = if bucket.len() < Self::MAX_FREELIST { + bucket.push(unsafe { NonNull::new_unchecked(obj) }); + true + } else { + false + }; + fl.set(list); + stored + }) + .unwrap_or(false) + } + + #[inline] + unsafe fn freelist_pop(payload: &Self) -> Option> { + let len = payload.elements.len(); + if len == 0 || len > TupleFreeList::MAX_SAVE_SIZE { + return None; + } + TUPLE_FREELIST + .try_with(|fl| { + let mut list = fl.take(); + let result = list.buckets[len - 1].pop(); + fl.set(list); + result + }) + .ok() + .flatten() + } } pub trait IntoPyTuple { diff --git a/crates/vm/src/builtins/type.rs b/crates/vm/src/builtins/type.rs index cca8c4692e6..3a1cebb9d10 100644 --- a/crates/vm/src/builtins/type.rs +++ b/crates/vm/src/builtins/type.rs @@ -3,22 +3,22 @@ use 
super::{ PyUtf8StrRef, PyWeak, mappingproxy::PyMappingProxy, object, union_, }; use crate::{ - AsObject, Context, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, TryFromObject, - VirtualMachine, + AsObject, Context, Py, PyAtomicRef, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, + TryFromObject, VirtualMachine, builtins::{ PyBaseExceptionRef, descriptor::{ MemberGetter, MemberKind, MemberSetter, PyDescriptorOwned, PyMemberDef, PyMemberDescriptor, }, - function::PyCellRef, + function::{PyCellRef, PyFunction}, tuple::{IntoPyTuple, PyTuple}, }, class::{PyClassImpl, StaticType}, common::{ ascii, borrow::BorrowedValue, - lock::{PyRwLock, PyRwLockReadGuard}, + lock::{PyMutex, PyRwLock, PyRwLockReadGuard}, }, function::{FuncArgs, KwArgs, OptionalArg, PyMethodDef, PySetterValue}, object::{Traverse, TraverseFn}, @@ -81,7 +81,12 @@ struct TypeCacheEntry { /// Interned attribute name pointer (pointer equality check). name: AtomicPtr, /// Cached lookup result as raw pointer. null = empty. - /// The cache holds a strong reference (refcount incremented). + /// The cache holds a **borrowed** pointer (no refcount increment). + /// Safety: `type_cache_clear()` nullifies all entries during GC, + /// and `type_cache_clear_version()` nullifies entries when a type + /// is modified — both before the source dict entry is removed. + /// Types are always part of reference cycles (via `mro` self-reference) + /// so they are always collected by the cyclic GC (never refcount-freed). value: AtomicPtr, } @@ -149,13 +154,11 @@ impl TypeCacheEntry { self.sequence.load(Ordering::Relaxed) == previous } - /// Take the value out of this entry, returning the owned PyObjectRef. + /// Null out the cached value pointer. /// Caller must ensure no concurrent reads can observe this entry /// (version should be set to 0 first). 
- fn take_value(&self) -> Option { - let ptr = self.value.swap(core::ptr::null_mut(), Ordering::Relaxed); - // SAFETY: non-null ptr was stored via PyObjectRef::into_raw - NonNull::new(ptr).map(|nn| unsafe { PyObjectRef::from_raw(nn) }) + fn clear_value(&self) { + self.value.store(core::ptr::null_mut(), Ordering::Relaxed); } } @@ -180,45 +183,36 @@ fn type_cache_hash(version: u32, name: &'static PyStrInterned) -> usize { ((version ^ name_hash) as usize) & TYPE_CACHE_MASK } -/// Invalidate cache entries for a specific version tag and release values. +/// Invalidate cache entries for a specific version tag. /// Called from modified() when a type is changed. fn type_cache_clear_version(version: u32) { - let mut to_drop = Vec::new(); for entry in TYPE_CACHE.iter() { if entry.version.load(Ordering::Relaxed) == version { entry.begin_write(); if entry.version.load(Ordering::Relaxed) == version { entry.version.store(0, Ordering::Release); - if let Some(v) = entry.take_value() { - to_drop.push(v); - } + entry.clear_value(); } entry.end_write(); } } - drop(to_drop); } /// Clear all method cache entries (_PyType_ClearCache). -/// Called during GC collection to release strong references that might -/// prevent cycle collection. +/// Called during GC collection to nullify borrowed pointers before +/// the collector breaks cycles. /// /// Sets TYPE_CACHE_CLEARING to suppress cache re-population during the /// entire operation, preventing concurrent lookups from repopulating /// entries while we're clearing them. pub fn type_cache_clear() { TYPE_CACHE_CLEARING.store(true, Ordering::Release); - // Invalidate all entries and collect values. 
- let mut to_drop = Vec::new(); for entry in TYPE_CACHE.iter() { entry.begin_write(); entry.version.store(0, Ordering::Release); - if let Some(v) = entry.take_value() { - to_drop.push(v); - } + entry.clear_value(); entry.end_write(); } - drop(to_drop); TYPE_CACHE_CLEARING.store(false, Ordering::Release); } @@ -233,6 +227,9 @@ unsafe impl crate::object::Traverse for PyType { .iter() .map(|(_, v)| v.traverse(tracer_fn)) .count(); + if let Some(ext) = self.heaptype_ext.as_ref() { + ext.specialization_cache.traverse(tracer_fn); + } } /// type_clear: break reference cycles in type objects @@ -260,6 +257,9 @@ unsafe impl crate::object::Traverse for PyType { out.push(val); } } + if let Some(ext) = self.heaptype_ext.as_ref() { + ext.specialization_cache.clear_into(out); + } } } @@ -269,6 +269,99 @@ pub struct HeapTypeExt { pub qualname: PyRwLock, pub slots: Option>>, pub type_data: PyRwLock>, + pub specialization_cache: TypeSpecializationCache, +} + +pub struct TypeSpecializationCache { + pub init: PyAtomicRef>, + pub getitem: PyAtomicRef>, + pub getitem_version: AtomicU32, + // Serialize cache writes/invalidation similar to CPython's BEGIN_TYPE_LOCK. + write_lock: PyMutex<()>, + retired: PyRwLock>, +} + +impl TypeSpecializationCache { + fn new() -> Self { + Self { + init: PyAtomicRef::from(None::>), + getitem: PyAtomicRef::from(None::>), + getitem_version: AtomicU32::new(0), + write_lock: PyMutex::new(()), + retired: PyRwLock::new(Vec::new()), + } + } + + #[inline] + fn retire_old_function(&self, old: Option>) { + if let Some(old) = old { + self.retired.write().push(old.into()); + } + } + + #[inline] + fn swap_init(&self, new_init: Option>, vm: Option<&VirtualMachine>) { + if let Some(vm) = vm { + // Keep replaced refs alive for the currently executing frame, matching + // CPython-style "old pointer remains valid during ongoing execution" + // without accumulating global retired refs. 
+ self.init.swap_to_temporary_refs(new_init, vm); + return; + } + // SAFETY: old value is moved to `retired`, so it stays alive while + // concurrent readers may still hold borrowed references. + let old = unsafe { self.init.swap(new_init) }; + self.retire_old_function(old); + } + + #[inline] + fn swap_getitem(&self, new_getitem: Option>, vm: Option<&VirtualMachine>) { + if let Some(vm) = vm { + self.getitem.swap_to_temporary_refs(new_getitem, vm); + return; + } + // SAFETY: old value is moved to `retired`, so it stays alive while + // concurrent readers may still hold borrowed references. + let old = unsafe { self.getitem.swap(new_getitem) }; + self.retire_old_function(old); + } + + #[inline] + fn invalidate_for_type_modified(&self) { + let _guard = self.write_lock.lock(); + // _spec_cache contract: type modification invalidates all cached + // specialization functions. + self.swap_init(None, None); + self.swap_getitem(None, None); + } + + fn traverse(&self, tracer_fn: &mut TraverseFn<'_>) { + if let Some(init) = self.init.deref() { + tracer_fn(init.as_object()); + } + if let Some(getitem) = self.getitem.deref() { + tracer_fn(getitem.as_object()); + } + self.retired + .read() + .iter() + .map(|obj| obj.traverse(tracer_fn)) + .count(); + } + + fn clear_into(&self, out: &mut Vec) { + let _guard = self.write_lock.lock(); + let old_init = unsafe { self.init.swap(None) }; + if let Some(old_init) = old_init { + out.push(old_init.into()); + } + let old_getitem = unsafe { self.getitem.swap(None) }; + if let Some(old_getitem) = old_getitem { + out.push(old_getitem.into()); + } + self.getitem_version.store(0, Ordering::Release); + out.extend(self.retired.write().drain(..)); + } } pub struct PointerSlot(NonNull); @@ -396,6 +489,9 @@ impl PyType { /// Invalidate this type's version tag and cascade to all subclasses. 
pub fn modified(&self) { + if let Some(ext) = self.heaptype_ext.as_ref() { + ext.specialization_cache.invalidate_for_type_modified(); + } // If already invalidated, all subclasses must also be invalidated // (guaranteed by the MRO invariant in assign_version_tag). let old_version = self.tp_version_tag.load(Ordering::Acquire); @@ -403,9 +499,8 @@ impl PyType { return; } self.tp_version_tag.store(0, Ordering::SeqCst); - // Release strong references held by cache entries for this version. - // We hold owned refs that would prevent GC of class attributes after - // type deletion. + // Nullify borrowed pointers in cache entries for this version + // so they don't dangle after the dict is modified. type_cache_clear_version(old_version); let subclasses = self.subclasses.read(); for weak_ref in subclasses.iter() { @@ -450,6 +545,7 @@ impl PyType { qualname: PyRwLock::new(name), slots: None, type_data: PyRwLock::new(None), + specialization_cache: TypeSpecializationCache::new(), }; let base = bases[0].clone(); @@ -562,6 +658,14 @@ impl PyType { slots.flags |= PyTypeFlags::HAS_DICT } + // Inherit HAS_WEAKREF/MANAGED_WEAKREF from any base in MRO that has it + if mro + .iter() + .any(|b| b.slots.flags.has_feature(PyTypeFlags::HAS_WEAKREF)) + { + slots.flags |= PyTypeFlags::HAS_WEAKREF | PyTypeFlags::MANAGED_WEAKREF + } + // Inherit SEQUENCE and MAPPING flags from base classes Self::inherit_patma_flags(&mut slots, &bases); @@ -574,6 +678,11 @@ impl PyType { Self::inherit_readonly_slots(&mut slots, &base); + // Normalize: any type with HAS_WEAKREF gets MANAGED_WEAKREF + if slots.flags.has_feature(PyTypeFlags::HAS_WEAKREF) { + slots.flags |= PyTypeFlags::MANAGED_WEAKREF; + } + if let Some(qualname) = attrs.get(identifier!(ctx, __qualname__)) && !qualname.fast_isinstance(ctx.types.str_type) { @@ -623,6 +732,9 @@ impl PyType { if base.slots.flags.has_feature(PyTypeFlags::HAS_DICT) { slots.flags |= PyTypeFlags::HAS_DICT } + if base.slots.flags.has_feature(PyTypeFlags::HAS_WEAKREF) { 
+ slots.flags |= PyTypeFlags::HAS_WEAKREF | PyTypeFlags::MANAGED_WEAKREF + } // Inherit SEQUENCE and MAPPING flags from base class // For static types, we only have a single base @@ -634,6 +746,11 @@ impl PyType { Self::inherit_readonly_slots(&mut slots, &base); + // Normalize: any type with HAS_WEAKREF gets MANAGED_WEAKREF + if slots.flags.has_feature(PyTypeFlags::HAS_WEAKREF) { + slots.flags |= PyTypeFlags::MANAGED_WEAKREF; + } + let bases = PyRwLock::new(vec![base.clone()]); let mro = base.mro_map_collect(|x| x.to_owned()); @@ -666,6 +783,7 @@ impl PyType { // slots are fully initialized by make_slots() Self::set_new(&new_type.slots, &new_type.base); + Self::set_alloc(&new_type.slots, &new_type.base); let weakref_type = super::PyWeak::static_type(); for base in new_type.bases.read().iter() { @@ -712,6 +830,7 @@ impl PyType { } Self::set_new(&self.slots, &self.base); + Self::set_alloc(&self.slots, &self.base); } fn set_new(slots: &PyTypeSlots, base: &Option) { @@ -726,6 +845,16 @@ impl PyType { } } + fn set_alloc(slots: &PyTypeSlots, base: &Option) { + if slots.alloc.load().is_none() { + slots.alloc.store( + base.as_ref() + .map(|base| base.slots.alloc.load()) + .unwrap_or(None), + ); + } + } + /// Inherit readonly slots from base type at creation time. /// These slots are not AtomicCell and must be set before the type is used. fn inherit_readonly_slots(slots: &mut PyTypeSlots, base: &Self) { @@ -756,9 +885,9 @@ impl PyType { } pub fn set_attr(&self, attr_name: &'static PyStrInterned, value: PyObjectRef) { - // Invalidate caches BEFORE modifying attributes so that cached - // descriptor pointers are still alive when type_cache_clear_version - // drops the cache's strong references. + // Invalidate caches BEFORE modifying attributes so that borrowed + // pointers in cache entries are nullified while the source objects + // are still alive. 
self.modified(); self.attributes.write().insert(attr_name, value); } @@ -769,6 +898,96 @@ impl PyType { self.find_name_in_mro(attr_name) } + /// Cache __init__ for CALL_ALLOC_AND_ENTER_INIT specialization. + /// The cache is valid only when guarded by the type version check. + pub(crate) fn cache_init_for_specialization( + &self, + init: PyRef, + tp_version: u32, + vm: &VirtualMachine, + ) -> bool { + let Some(ext) = self.heaptype_ext.as_ref() else { + return false; + }; + if tp_version == 0 { + return false; + } + if self.tp_version_tag.load(Ordering::Acquire) != tp_version { + return false; + } + let _guard = ext.specialization_cache.write_lock.lock(); + if self.tp_version_tag.load(Ordering::Acquire) != tp_version { + return false; + } + ext.specialization_cache.swap_init(Some(init), Some(vm)); + true + } + + /// Read cached __init__ for CALL_ALLOC_AND_ENTER_INIT specialization. + pub(crate) fn get_cached_init_for_specialization( + &self, + tp_version: u32, + ) -> Option> { + let ext = self.heaptype_ext.as_ref()?; + if tp_version == 0 { + return None; + } + if self.tp_version_tag.load(Ordering::Acquire) != tp_version { + return None; + } + ext.specialization_cache + .init + .to_owned_ordering(Ordering::Acquire) + } + + /// Cache __getitem__ for BINARY_OP_SUBSCR_GETITEM specialization. + /// The cache is valid only when guarded by the type version check. 
+ pub(crate) fn cache_getitem_for_specialization( + &self, + getitem: PyRef, + tp_version: u32, + vm: &VirtualMachine, + ) -> bool { + let Some(ext) = self.heaptype_ext.as_ref() else { + return false; + }; + if tp_version == 0 { + return false; + } + let _guard = ext.specialization_cache.write_lock.lock(); + if self.tp_version_tag.load(Ordering::Acquire) != tp_version { + return false; + } + let func_version = getitem.get_version_for_current_state(); + if func_version == 0 { + return false; + } + ext.specialization_cache + .swap_getitem(Some(getitem), Some(vm)); + ext.specialization_cache + .getitem_version + .store(func_version, Ordering::Relaxed); + true + } + + /// Read cached __getitem__ for BINARY_OP_SUBSCR_GETITEM specialization. + pub(crate) fn get_cached_getitem_for_specialization(&self) -> Option<(PyRef, u32)> { + let ext = self.heaptype_ext.as_ref()?; + // Match CPython check order: pointer (Acquire) then function version. + let getitem = ext + .specialization_cache + .getitem + .to_owned_ordering(Ordering::Acquire)?; + let cached_version = ext + .specialization_cache + .getitem_version + .load(Ordering::Relaxed); + if cached_version == 0 { + return None; + } + Some((getitem, cached_version)) + } + pub fn get_direct_attr(&self, attr_name: &'static PyStrInterned) -> Option { self.attributes.read().get(attr_name).cloned() } @@ -841,19 +1060,13 @@ impl PyType { entry.begin_write(); // Invalidate first to prevent readers from seeing partial state entry.version.store(0, Ordering::Release); - // Swap in new value (refcount held by cache) - let new_ptr = found.clone().into_raw().as_ptr(); - let old_ptr = entry.value.swap(new_ptr, Ordering::Relaxed); + // Store borrowed pointer (no refcount increment). 
+ let new_ptr = &**found as *const PyObject as *mut PyObject; + entry.value.store(new_ptr, Ordering::Relaxed); entry.name.store(name_ptr, Ordering::Relaxed); // Activate entry — Release ensures value/name writes are visible entry.version.store(assigned, Ordering::Release); entry.end_write(); - // Drop previous occupant (its version was already invalidated) - if !old_ptr.is_null() { - unsafe { - drop(PyObjectRef::from_raw(NonNull::new_unchecked(old_ptr))); - } - } } result @@ -1083,7 +1296,7 @@ impl Py { AsNumber, Representable ), - flags(BASETYPE) + flags(BASETYPE, HAS_DICT, HAS_WEAKREF) )] impl PyType { #[pygetset] @@ -1264,7 +1477,7 @@ impl PyType { fn set___annotate__(&self, value: PySetterValue, vm: &VirtualMachine) -> PyResult<()> { let value = match value { PySetterValue::Delete => { - return Err(vm.new_type_error("cannot delete __annotate__ attribute".to_owned())); + return Err(vm.new_type_error("cannot delete __annotate__ attribute")); } PySetterValue::Assign(v) => v, }; @@ -1277,7 +1490,7 @@ impl PyType { } if !vm.is_none(&value) && !value.is_callable() { - return Err(vm.new_type_error("__annotate__ must be callable or None".to_owned())); + return Err(vm.new_type_error("__annotate__ must be callable or None")); } let mut attrs = self.attributes.write(); @@ -1394,7 +1607,7 @@ impl PyType { .is_some() }; if !removed { - return Err(vm.new_attribute_error("__annotations__".to_owned())); + return Err(vm.new_attribute_error("__annotations__")); } if has_annotations { attrs.swap_remove(identifier!(vm, __annotations_cache__)); @@ -1709,141 +1922,125 @@ impl Constructor for PyType { attributes.insert(identifier!(vm, __hash__), vm.ctx.none.clone().into()); } - let (heaptype_slots, add_dict): (Option>>, bool) = - if let Some(x) = attributes.get(identifier!(vm, __slots__)) { - // Check if __slots__ is bytes - not allowed - if x.class().is(vm.ctx.types.bytes_type) { - return Err(vm.new_type_error("__slots__ items must be strings, not 'bytes'")); - } - - let slots = if 
x.class().is(vm.ctx.types.str_type) { - let x = unsafe { x.downcast_unchecked_ref::() }; - PyTuple::new_ref_typed(vec![x.to_owned()], &vm.ctx) - } else { - let iter = x.get_iter(vm)?; - let elements = { - let mut elements = Vec::new(); - while let PyIterReturn::Return(element) = iter.next(vm)? { - // Check if any slot item is bytes - if element.class().is(vm.ctx.types.bytes_type) { - return Err(vm.new_type_error( - "__slots__ items must be strings, not 'bytes'", - )); - } - elements.push(element); - } - elements - }; - let tuple = elements.into_pytuple(vm); - tuple.try_into_typed(vm)? - }; - - // Check if base has itemsize > 0 - can't add arbitrary slots to variable-size types - // Types like int, bytes, tuple have itemsize > 0 and don't allow custom slots - // But types like weakref.ref have itemsize = 0 and DO allow slots - let has_custom_slots = slots - .iter() - .any(|s| !matches!(s.as_bytes(), b"__dict__" | b"__weakref__")); - if has_custom_slots && base.slots.itemsize > 0 { - return Err(vm.new_type_error(format!( - "nonempty __slots__ not supported for subtype of '{}'", - base.name() - ))); - } - - // Validate slot names and track duplicates - let mut seen_dict = false; - let mut seen_weakref = false; - for slot in slots.iter() { - // Use isidentifier for validation (handles Unicode properly) - if !slot.isidentifier() { - return Err(vm.new_type_error("__slots__ must be identifiers")); - } - - let slot_name = slot.as_bytes(); + let (heaptype_slots, add_dict, add_weakref): ( + Option>>, + bool, + bool, + ) = if let Some(x) = attributes.get(identifier!(vm, __slots__)) { + // Check if __slots__ is bytes - not allowed + if x.class().is(vm.ctx.types.bytes_type) { + return Err(vm.new_type_error("__slots__ items must be strings, not 'bytes'")); + } - // Check for duplicate __dict__ - if slot_name == b"__dict__" { - if seen_dict { + let slots = if x.class().is(vm.ctx.types.str_type) { + let x = unsafe { x.downcast_unchecked_ref::() }; + 
PyTuple::new_ref_typed(vec![x.to_owned()], &vm.ctx) + } else { + let iter = x.get_iter(vm)?; + let elements = { + let mut elements = Vec::new(); + while let PyIterReturn::Return(element) = iter.next(vm)? { + // Check if any slot item is bytes + if element.class().is(vm.ctx.types.bytes_type) { return Err( - vm.new_type_error("__dict__ slot disallowed: we already got one") + vm.new_type_error("__slots__ items must be strings, not 'bytes'") ); } - seen_dict = true; + elements.push(element); } + elements + }; + let tuple = elements.into_pytuple(vm); + tuple.try_into_typed(vm)? + }; - // Check for duplicate __weakref__ - if slot_name == b"__weakref__" { - if seen_weakref { - return Err(vm.new_type_error( - "__weakref__ slot disallowed: we already got one", - )); - } - seen_weakref = true; - } + // Check if base has itemsize > 0 - can't add arbitrary slots to variable-size types + // Types like int, bytes, tuple have itemsize > 0 and don't allow custom slots + // But types like weakref.ref have itemsize = 0 and DO allow slots + let has_custom_slots = slots + .iter() + .any(|s| !matches!(s.as_bytes(), b"__dict__" | b"__weakref__")); + if has_custom_slots && base.slots.itemsize > 0 { + return Err(vm.new_type_error(format!( + "nonempty __slots__ not supported for subtype of '{}'", + base.name() + ))); + } - // Check if slot name conflicts with class attributes - if attributes.contains_key(vm.ctx.intern_str(slot.as_wtf8())) { - return Err(vm.new_value_error(format!( - "'{}' in __slots__ conflicts with a class variable", - slot.as_wtf8() - ))); - } + // Validate slot names and track duplicates + let mut seen_dict = false; + let mut seen_weakref = false; + for slot in slots.iter() { + // Use isidentifier for validation (handles Unicode properly) + if !slot.isidentifier() { + return Err(vm.new_type_error("__slots__ must be identifiers")); } - // Check if base class already has __dict__ - can't redefine it - if seen_dict && base.slots.flags.has_feature(PyTypeFlags::HAS_DICT) { - 
return Err(vm.new_type_error("__dict__ slot disallowed: we already got one")); - } + let slot_name = slot.as_bytes(); - // Check if base class already has __weakref__ - can't redefine it - // A base has weakref support if: - // 1. It's a heap type without explicit __slots__ (automatic weakref), OR - // 2. It's a heap type with __weakref__ in its __slots__ - if seen_weakref { - let base_has_weakref = if let Some(ref ext) = base.heaptype_ext { - match &ext.slots { - // Heap type without __slots__ - has automatic weakref - None => true, - // Heap type with __slots__ - check if __weakref__ is in slots - Some(base_slots) => { - base_slots.iter().any(|s| s.as_bytes() == b"__weakref__") - } - } - } else { - // Builtin type - check if it has __weakref__ descriptor - let weakref_name = vm.ctx.intern_str("__weakref__"); - base.attributes.read().contains_key(weakref_name) - }; + // Check for duplicate __dict__ + if slot_name == b"__dict__" { + if seen_dict { + return Err( + vm.new_type_error("__dict__ slot disallowed: we already got one") + ); + } + seen_dict = true; + } - if base_has_weakref { + // Check for duplicate __weakref__ + if slot_name == b"__weakref__" { + if seen_weakref { return Err( vm.new_type_error("__weakref__ slot disallowed: we already got one") ); } + seen_weakref = true; } - // Check if __dict__ is in slots - let dict_name = "__dict__"; - let has_dict = slots.iter().any(|s| s.as_wtf8() == dict_name); - - // Filter out __dict__ from slots - let filtered_slots = if has_dict { - let filtered: Vec = slots - .iter() - .filter(|s| s.as_wtf8() != dict_name) - .cloned() - .collect(); - PyTuple::new_ref_typed(filtered, &vm.ctx) - } else { - slots - }; + // Check if slot name conflicts with class attributes + if attributes.contains_key(vm.ctx.intern_str(slot.as_wtf8())) { + return Err(vm.new_value_error(format!( + "'{}' in __slots__ conflicts with a class variable", + slot.as_wtf8() + ))); + } + } + + // Check if base class already has __dict__ - can't redefine it 
+ if seen_dict && base.slots.flags.has_feature(PyTypeFlags::HAS_DICT) { + return Err(vm.new_type_error("__dict__ slot disallowed: we already got one")); + } + + // Check if base class already has __weakref__ - can't redefine it + if seen_weakref && base.slots.flags.has_feature(PyTypeFlags::HAS_WEAKREF) { + return Err(vm.new_type_error("__weakref__ slot disallowed: we already got one")); + } + + // Check if __dict__ or __weakref__ is in slots + let dict_name = "__dict__"; + let weakref_name = "__weakref__"; + let has_dict = slots.iter().any(|s| s.as_wtf8() == dict_name); + let add_weakref = seen_weakref; - (Some(filtered_slots), has_dict) + // Filter out __dict__ and __weakref__ from slots + // (they become descriptors, not member slots) + let filtered_slots = if has_dict || add_weakref { + let filtered: Vec = slots + .iter() + .filter(|s| s.as_wtf8() != dict_name && s.as_wtf8() != weakref_name) + .cloned() + .collect(); + PyTuple::new_ref_typed(filtered, &vm.ctx) } else { - (None, false) + slots }; + (Some(filtered_slots), has_dict, add_weakref) + } else { + (None, false, false) + }; + // FIXME: this is a temporary fix. multi bases with multiple slots will break object let base_member_count = bases .iter() @@ -1867,6 +2064,14 @@ impl Constructor for PyType { flags |= PyTypeFlags::HAS_DICT | PyTypeFlags::MANAGED_DICT; } + // Add HAS_WEAKREF if: + // 1. __slots__ is not defined (automatic weakref support), OR + // 2. 
__weakref__ is in __slots__ + let may_add_weakref = !base.slots.flags.has_feature(PyTypeFlags::HAS_WEAKREF); + if (heaptype_slots.is_none() && may_add_weakref) || add_weakref { + flags |= PyTypeFlags::HAS_WEAKREF | PyTypeFlags::MANAGED_WEAKREF; + } + let (slots, heaptype_ext) = { let slots = PyTypeSlots { flags, @@ -1879,6 +2084,7 @@ impl Constructor for PyType { qualname: PyRwLock::new(qualname), slots: heaptype_slots.clone(), type_data: PyRwLock::new(None), + specialization_cache: TypeSpecializationCache::new(), }; (slots, heaptype_ext) }; @@ -1899,9 +2105,9 @@ impl Constructor for PyType { let class_name = typ.name().to_string(); for member in slots.as_slice() { // Apply name mangling for private attributes (__x -> _ClassName__x) - let member_str = member.to_str().ok_or_else(|| { - vm.new_type_error("__slots__ must be valid UTF-8 strings".to_owned()) - })?; + let member_str = member + .to_str() + .ok_or_else(|| vm.new_type_error("__slots__ must be valid UTF-8 strings"))?; let mangled_name = mangle_name(&class_name, member_str); let member_def = PyMemberDef { name: mangled_name.clone(), @@ -1965,6 +2171,29 @@ impl Constructor for PyType { } } + // Add __weakref__ descriptor for types with HAS_WEAKREF + if typ.slots.flags.has_feature(PyTypeFlags::HAS_WEAKREF) { + let __weakref__ = vm.ctx.intern_str("__weakref__"); + let has_inherited_weakref = typ + .mro + .read() + .iter() + .any(|base| base.attributes.read().contains_key(&__weakref__)); + if !typ.attributes.read().contains_key(&__weakref__) && !has_inherited_weakref { + unsafe { + let descriptor = vm.ctx.new_getset( + "__weakref__", + &typ, + subtype_get_weakref, + subtype_set_weakref, + ); + typ.attributes + .write() + .insert(__weakref__, descriptor.into()); + } + } + } + // Set __doc__ to None if not already present in the type's dict // This matches CPython's behavior in type_dict_set_doc (typeobject.c) // which ensures every type has a __doc__ entry in its dict @@ -2056,10 +2285,10 @@ impl Initializer for 
PyType { fn slot_init(_zelf: PyObjectRef, args: FuncArgs, vm: &VirtualMachine) -> PyResult<()> { // type.__init__() takes 1 or 3 arguments if args.args.len() == 1 && !args.kwargs.is_empty() { - return Err(vm.new_type_error("type.__init__() takes no keyword arguments".to_owned())); + return Err(vm.new_type_error("type.__init__() takes no keyword arguments")); } if args.args.len() != 1 && args.args.len() != 3 { - return Err(vm.new_type_error("type.__init__() takes 1 or 3 arguments".to_owned())); + return Err(vm.new_type_error("type.__init__() takes 1 or 3 arguments")); } Ok(()) } @@ -2188,7 +2417,7 @@ impl Py { #[pymethod] fn __instancecheck__(&self, obj: PyObjectRef, vm: &VirtualMachine) -> PyResult { - // Use real_is_instance to avoid infinite recursion, matching CPython's behavior + // Use real_is_instance to avoid infinite recursion obj.real_is_instance(self.as_object(), vm) } @@ -2273,7 +2502,7 @@ impl Callable for PyType { return Ok(args.args[0].obj_type()); } if num_args != 3 { - return Err(vm.new_type_error("type() takes 1 or 3 arguments".to_owned())); + return Err(vm.new_type_error("type() takes 1 or 3 arguments")); } } @@ -2400,6 +2629,21 @@ fn subtype_set_dict(obj: PyObjectRef, value: PyObjectRef, vm: &VirtualMachine) - } } +// subtype_get_weakref +fn subtype_get_weakref(obj: PyObjectRef, vm: &VirtualMachine) -> PyResult { + // Return the first weakref in the weakref list, or None + let weakref = obj.get_weakrefs(); + Ok(weakref.unwrap_or_else(|| vm.ctx.none())) +} + +// subtype_set_weakref: __weakref__ is read-only +fn subtype_set_weakref(obj: PyObjectRef, _value: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { + Err(vm.new_attribute_error(format!( + "attribute '__weakref__' of '{}' objects is not writable", + obj.class().name() + ))) +} + /* * The magical type type */ diff --git a/crates/vm/src/builtins/union.rs b/crates/vm/src/builtins/union.rs index 830465e49f5..0f7ce123721 100644 --- a/crates/vm/src/builtins/union.rs +++ 
b/crates/vm/src/builtins/union.rs @@ -9,7 +9,7 @@ use crate::{ convert::ToPyObject, function::PyComparisonValue, protocol::{PyMappingMethods, PyNumberMethods}, - stdlib::typing::{TypeAliasType, call_typing_func_object}, + stdlib::_typing::{TypeAliasType, call_typing_func_object}, types::{AsMapping, AsNumber, Comparable, GetAttr, Hashable, PyComparisonOp, Representable}, }; use alloc::fmt; @@ -51,7 +51,7 @@ impl PyUnion { }) } - /// Direct access to args field, matching CPython's _Py_union_args + /// Direct access to args field (_Py_union_args) #[inline] pub fn args(&self) -> &Py { &self.args @@ -98,7 +98,7 @@ impl PyUnion { } #[pyclass( - flags(DISALLOW_INSTANTIATION), + flags(DISALLOW_INSTANTIATION, HAS_WEAKREF), with(Hashable, Comparable, AsMapping, AsNumber, Representable) )] impl PyUnion { @@ -292,8 +292,8 @@ fn dedup_and_flatten_args(args: &Py, vm: &VirtualMachine) -> PyResult = Vec::with_capacity(args.len()); diff --git a/crates/vm/src/builtins/weakref.rs b/crates/vm/src/builtins/weakref.rs index d3087eedb9e..e1b9545252d 100644 --- a/crates/vm/src/builtins/weakref.rs +++ b/crates/vm/src/builtins/weakref.rs @@ -45,9 +45,9 @@ impl Constructor for PyWeak { // PyArg_UnpackTuple: only process positional args, ignore kwargs. // Subclass __init__ will handle extra kwargs. 
let mut positional = args.args.into_iter(); - let referent = positional.next().ok_or_else(|| { - vm.new_type_error("__new__ expected at least 1 argument, got 0".to_owned()) - })?; + let referent = positional + .next() + .ok_or_else(|| vm.new_type_error("__new__ expected at least 1 argument, got 0"))?; let callback = positional.next(); if let Some(_extra) = positional.next() { return Err(vm.new_type_error(format!( diff --git a/crates/vm/src/bytes_inner.rs b/crates/vm/src/bytes_inner.rs index d8e4a6c8eff..2318415f0fe 100644 --- a/crates/vm/src/bytes_inner.rs +++ b/crates/vm/src/bytes_inner.rs @@ -461,7 +461,7 @@ impl PyBytesInner { match invalid_char { None => Err(vm.new_value_error( - "fromhex() arg must contain an even number of hexadecimal digits".to_owned(), + "fromhex() arg must contain an even number of hexadecimal digits", )), Some(i) => Err(vm.new_value_error(format!( "non-hexadecimal number found in fromhex() arg at position {i}" @@ -474,9 +474,9 @@ impl PyBytesInner { if let Some(s) = string.downcast_ref::() { Self::fromhex(s.as_bytes(), vm) } else if let Ok(buffer) = PyBuffer::try_from_borrowed_object(vm, &string) { - let borrowed = buffer.as_contiguous().ok_or_else(|| { - vm.new_buffer_error("fromhex() requires a contiguous buffer".to_owned()) - })?; + let borrowed = buffer + .as_contiguous() + .ok_or_else(|| vm.new_buffer_error("fromhex() requires a contiguous buffer"))?; Self::fromhex(&borrowed, vm) } else { Err(vm.new_type_error(format!( diff --git a/crates/vm/src/convert/try_from.rs b/crates/vm/src/convert/try_from.rs index b8d1b53e2e7..85d6f5e20e3 100644 --- a/crates/vm/src/convert/try_from.rs +++ b/crates/vm/src/convert/try_from.rs @@ -127,15 +127,13 @@ impl TryFromObject for core::time::Duration { if let Some(float) = obj.downcast_ref::() { let f = float.to_f64(); if f.is_nan() { - return Err(vm.new_value_error("Invalid value NaN (not a number)".to_owned())); + return Err(vm.new_value_error("Invalid value NaN (not a number)")); } if f < 0.0 { - 
return Err(vm.new_value_error("negative duration".to_owned())); + return Err(vm.new_value_error("negative duration")); } if !f.is_finite() || f > u64::MAX as f64 { - return Err(vm.new_overflow_error( - "timestamp too large to convert to C PyTime_t".to_owned(), - )); + return Err(vm.new_overflow_error("timestamp too large to convert to C PyTime_t")); } // Convert float to Duration using floor rounding (_PyTime_ROUND_FLOOR) let secs = f.trunc() as u64; diff --git a/crates/vm/src/coroutine.rs b/crates/vm/src/coroutine.rs index ac7aeba5443..07158c48859 100644 --- a/crates/vm/src/coroutine.rs +++ b/crates/vm/src/coroutine.rs @@ -115,27 +115,12 @@ impl Coro { result } - pub fn send( + fn finalize_send_result( &self, jen: &PyObject, - value: PyObjectRef, + result: PyResult, vm: &VirtualMachine, ) -> PyResult { - if self.closed.load() { - return Ok(PyIterReturn::StopIteration(None)); - } - self.frame.locals_to_fast(vm)?; - let value = if self.frame.lasti() > 0 { - Some(value) - } else if !vm.is_none(&value) { - return Err(vm.new_type_error(format!( - "can't send non-None value to a just-started {}", - gen_name(jen, vm), - ))); - } else { - None - }; - let result = self.run_with_context(jen, vm, |f| f.resume(value, vm)); self.maybe_close(&result); match result { Ok(exec_res) => Ok(exec_res.into_iter_return(vm)), @@ -158,6 +143,44 @@ impl Coro { } } + pub(crate) fn send_none(&self, jen: &PyObject, vm: &VirtualMachine) -> PyResult { + if self.closed.load() { + return Ok(PyIterReturn::StopIteration(None)); + } + self.frame.locals_to_fast(vm)?; + let value = if self.frame.lasti() > 0 { + Some(vm.ctx.none()) + } else { + None + }; + let result = self.run_with_context(jen, vm, |f| f.resume(value, vm)); + self.finalize_send_result(jen, result, vm) + } + + pub fn send( + &self, + jen: &PyObject, + value: PyObjectRef, + vm: &VirtualMachine, + ) -> PyResult { + if self.closed.load() { + return Ok(PyIterReturn::StopIteration(None)); + } + self.frame.locals_to_fast(vm)?; + let value = 
if self.frame.lasti() > 0 { + Some(value) + } else if !vm.is_none(&value) { + return Err(vm.new_type_error(format!( + "can't send non-None value to a just-started {}", + gen_name(jen, vm), + ))); + } else { + None + }; + let result = self.run_with_context(jen, vm, |f| f.resume(value, vm)); + self.finalize_send_result(jen, result, vm) + } + pub fn throw( &self, jen: &PyObject, @@ -166,17 +189,13 @@ impl Coro { exc_tb: PyObjectRef, vm: &VirtualMachine, ) -> PyResult { - // Validate throw arguments (matching CPython _gen_throw) + // Validate throw arguments (_gen_throw) if exc_type.fast_isinstance(vm.ctx.exceptions.base_exception_type) && !vm.is_none(&exc_val) { - return Err( - vm.new_type_error("instance exception may not have a separate value".to_owned()) - ); + return Err(vm.new_type_error("instance exception may not have a separate value")); } if !vm.is_none(&exc_tb) && !exc_tb.fast_isinstance(vm.ctx.types.traceback_type) { - return Err( - vm.new_type_error("throw() third argument must be a traceback object".to_owned()) - ); + return Err(vm.new_type_error("throw() third argument must be a traceback object")); } if self.closed.load() { return Err(vm.normalize_exception(exc_type, exc_val, exc_tb)?); @@ -299,7 +318,7 @@ pub fn get_awaitable_iter(obj: PyObjectRef, vm: &VirtualMachine) -> PyResult { .contains(crate::bytecode::CodeFlags::ITERABLE_COROUTINE) }) { - return Err(vm.new_type_error("__await__() returned a coroutine".to_owned())); + return Err(vm.new_type_error("__await__() returned a coroutine")); } if !PyIter::check(&result) { return Err(vm.new_type_error(format!( diff --git a/crates/vm/src/exception_group.rs b/crates/vm/src/exception_group.rs index f6abdee0fab..02342e4003d 100644 --- a/crates/vm/src/exception_group.rs +++ b/crates/vm/src/exception_group.rs @@ -276,9 +276,9 @@ pub(super) mod types { // Validate non-empty if exceptions.is_empty() { - return Err(vm.new_value_error( - "second argument (exceptions) must be a non-empty sequence".to_owned(), - )); 
+ return Err( + vm.new_value_error("second argument (exceptions) must be a non-empty sequence") + ); } // Validate all items are BaseException instances diff --git a/crates/vm/src/exceptions.rs b/crates/vm/src/exceptions.rs index b72b89b4768..f32005bd348 100644 --- a/crates/vm/src/exceptions.rs +++ b/crates/vm/src/exceptions.rs @@ -2171,7 +2171,7 @@ pub(super) mod types { fn characters_written(&self, vm: &VirtualMachine) -> PyResult { let written = self.written.load(); if written == -1 { - Err(vm.new_attribute_error("characters_written".to_owned())) + Err(vm.new_attribute_error("characters_written")) } else { Ok(written) } @@ -2187,7 +2187,7 @@ pub(super) mod types { None => { // Deleting the attribute if self.written.load() == -1 { - Err(vm.new_attribute_error("characters_written".to_owned())) + Err(vm.new_attribute_error("characters_written")) } else { self.written.store(-1); Ok(()) @@ -2198,9 +2198,7 @@ pub(super) mod types { .try_index(vm)? .try_to_primitive::(vm) .map_err(|_| { - vm.new_value_error( - "cannot convert characters_written value to isize".to_owned(), - ) + vm.new_value_error("cannot convert characters_written value to isize") })?; self.written.store(n); Ok(()) @@ -2714,14 +2712,13 @@ fn check_except_star_type_valid(match_type: &PyObjectRef, vm: &VirtualMachine) - // Must be a subclass of BaseException if !exc_type.is_subclass(&base_exc, vm)? { return Err(vm.new_type_error( - "catching classes that do not inherit from BaseException is not allowed".to_owned(), + "catching classes that do not inherit from BaseException is not allowed", )); } // Must not be a subclass of BaseExceptionGroup if exc_type.is_subclass(&base_eg, vm)? { return Err(vm.new_type_error( - "catching ExceptionGroup with except* is not allowed. Use except instead." - .to_owned(), + "catching ExceptionGroup with except* is not allowed. 
Use except instead.", )); } Ok(()) diff --git a/crates/vm/src/format.rs b/crates/vm/src/format.rs index 04d06e9be07..95bd893baea 100644 --- a/crates/vm/src/format.rs +++ b/crates/vm/src/format.rs @@ -9,6 +9,64 @@ use crate::{ use crate::common::format::*; use crate::common::wtf8::{Wtf8, Wtf8Buf}; +/// Get locale information from C `localeconv()` for the 'n' format specifier. +#[cfg(unix)] +pub(crate) fn get_locale_info() -> LocaleInfo { + use core::ffi::CStr; + unsafe { + let lc = libc::localeconv(); + if lc.is_null() { + return LocaleInfo { + thousands_sep: String::new(), + decimal_point: ".".to_string(), + grouping: vec![], + }; + } + let thousands_sep = CStr::from_ptr((*lc).thousands_sep) + .to_string_lossy() + .into_owned(); + let decimal_point = CStr::from_ptr((*lc).decimal_point) + .to_string_lossy() + .into_owned(); + let grouping = parse_grouping((*lc).grouping); + LocaleInfo { + thousands_sep, + decimal_point, + grouping, + } + } +} + +#[cfg(not(unix))] +pub(crate) fn get_locale_info() -> LocaleInfo { + LocaleInfo { + thousands_sep: String::new(), + decimal_point: ".".to_string(), + grouping: vec![], + } +} + +/// Parse C `lconv.grouping` into a `Vec`. +/// Reads bytes until 0 or CHAR_MAX, then appends 0 (meaning "repeat last group"). 
+#[cfg(unix)] +unsafe fn parse_grouping(grouping: *const libc::c_char) -> Vec { + let mut result = Vec::new(); + if grouping.is_null() { + return result; + } + unsafe { + let mut ptr = grouping; + while ![0, libc::c_char::MAX].contains(&*ptr) { + result.push(*ptr as u8); + ptr = ptr.add(1); + } + } + if !result.is_empty() { + result.push(0); + } + result +} + impl IntoPyException for FormatSpecError { fn into_pyexception(self, vm: &VirtualMachine) -> PyBaseExceptionRef { match self { diff --git a/crates/vm/src/frame.rs b/crates/vm/src/frame.rs index 376c9ed6bd1..b3c36277912 100644 --- a/crates/vm/src/frame.rs +++ b/crates/vm/src/frame.rs @@ -1,3 +1,5 @@ +// spell-checker: ignore compactlong compactlongs + use crate::anystr::AnyStr; #[cfg(feature = "flame")] use crate::bytecode::InstructionMetadata; @@ -12,7 +14,10 @@ use crate::{ builtin_func::PyNativeFunction, descriptor::{MemberGetter, PyMemberDescriptor, PyMethodDescriptor}, frame::stack_analysis, - function::{PyBoundMethod, PyCell, PyCellRef, PyFunction, vectorcall_function}, + function::{ + PyBoundMethod, PyCell, PyCellRef, PyFunction, datastack_frame_size_bytes_for_code, + vectorcall_function, + }, list::PyListIterator, range::PyRangeIterator, tuple::{PyTuple, PyTupleIterator, PyTupleRef}, @@ -23,13 +28,13 @@ use crate::{ convert::{ToPyObject, ToPyResult}, coroutine::Coro, exceptions::ExceptionCtor, - function::{ArgMapping, Either, FuncArgs}, + function::{ArgMapping, Either, FuncArgs, PyMethodFlags}, object::PyAtomicBorrow, object::{Traverse, TraverseFn}, protocol::{PyIter, PyIterReturn}, scope::Scope, sliceable::SliceableSequenceOp, - stdlib::{builtins, sys::monitoring, typing}, + stdlib::{_typing, builtins, sys::monitoring}, types::{PyComparisonOp, PyTypeFlags}, vm::{Context, PyMethod}, }; @@ -97,7 +102,7 @@ impl FrameOwner { /// a given frame (enforced by the owner field and generator running flag). /// External readers (e.g. 
`f_locals`) are on the same thread as execution /// (trace callback) or the frame is not executing. -struct FrameUnsafeCell(UnsafeCell); +pub(crate) struct FrameUnsafeCell(UnsafeCell); impl FrameUnsafeCell { fn new(value: T) -> Self { @@ -565,13 +570,18 @@ unsafe impl Traverse for FrameLocals { } } -#[pyclass(module = false, name = "frame", traverse = "manual")] -pub struct Frame { +/// Lightweight execution frame. Not a PyObject. +/// Analogous to CPython's `_PyInterpreterFrame`. +/// +/// Currently always embedded inside a `Frame` PyObject via `FrameUnsafeCell`. +/// In future PRs this will be usable independently for normal function calls +/// (allocated on the Rust stack + DataStack), eliminating PyObject overhead. +pub struct InterpreterFrame { pub code: PyRef, pub func_obj: Option, /// Unified storage for local variables and evaluation stack. - localsplus: FrameUnsafeCell, + pub(crate) localsplus: LocalsPlus, pub locals: FrameLocals, pub globals: PyDictRef, pub builtins: PyObjectRef, @@ -581,10 +591,8 @@ pub struct Frame { /// tracer function for this frame (usually is None) pub trace: PyMutex, - /// Cell and free variable references (cellvars + freevars). - cells_frees: FrameUnsafeCell>, /// Previous line number for LINE event suppression. - prev_line: FrameUnsafeCell, + pub(crate) prev_line: u32, // member pub trace_lines: PyMutex, @@ -613,6 +621,28 @@ pub struct Frame { pub(crate) pending_unwind_from_stack: PyAtomic, } +/// Python-visible frame object. Currently always wraps an `InterpreterFrame`. +/// Analogous to CPython's `PyFrameObject`. +#[pyclass(module = false, name = "frame", traverse = "manual")] +pub struct Frame { + pub(crate) iframe: FrameUnsafeCell, +} + +impl core::ops::Deref for Frame { + type Target = InterpreterFrame; + /// Transparent access to InterpreterFrame fields. + /// + /// # Safety argument + /// Immutable fields (code, globals, builtins, func_obj, locals) are safe + /// to access at any time. 
Atomic/mutex fields (lasti, trace, owner, etc.) + /// provide their own synchronization. Mutable fields (localsplus, prev_line) + /// are only mutated during single-threaded execution via `with_exec`. + #[inline(always)] + fn deref(&self) -> &InterpreterFrame { + unsafe { &*self.iframe.get() } + } +} + impl PyPayload for Frame { #[inline] fn class(ctx: &Context) -> &'static Py { @@ -622,18 +652,16 @@ impl PyPayload for Frame { unsafe impl Traverse for Frame { fn traverse(&self, tracer_fn: &mut TraverseFn<'_>) { - self.code.traverse(tracer_fn); - self.func_obj.traverse(tracer_fn); // SAFETY: GC traversal does not run concurrently with frame execution. - unsafe { - (*self.localsplus.get()).traverse(tracer_fn); - (*self.cells_frees.get()).traverse(tracer_fn); - } - self.locals.traverse(tracer_fn); - self.globals.traverse(tracer_fn); - self.builtins.traverse(tracer_fn); - self.trace.traverse(tracer_fn); - self.temporary_refs.traverse(tracer_fn); + let iframe = unsafe { &*self.iframe.get() }; + iframe.code.traverse(tracer_fn); + iframe.func_obj.traverse(tracer_fn); + iframe.localsplus.traverse(tracer_fn); + iframe.locals.traverse(tracer_fn); + iframe.globals.traverse(tracer_fn); + iframe.builtins.traverse(tracer_fn); + iframe.trace.traverse(tracer_fn); + iframe.temporary_refs.traverse(tracer_fn); } } @@ -660,12 +688,6 @@ impl Frame { let num_cells = code.cellvars.len(); let nfrees = closure.len(); - let cells_frees: Box<[PyCellRef]> = - core::iter::repeat_with(|| PyCell::default().into_ref(&vm.ctx)) - .take(num_cells) - .chain(closure.iter().cloned()) - .collect(); - let nlocalsplus = nlocals .checked_add(num_cells) .and_then(|v| v.checked_add(nfrees)) @@ -677,13 +699,17 @@ impl Frame { LocalsPlus::new(nlocalsplus, max_stackdepth) }; - // Store cell objects at cellvars and freevars positions - for (i, cell) in cells_frees.iter().enumerate() { - localsplus.fastlocals_mut()[nlocals + i] = Some(cell.clone().into()); + // Store cell/free variable objects directly in 
localsplus + let fastlocals = localsplus.fastlocals_mut(); + for i in 0..num_cells { + fastlocals[nlocals + i] = Some(PyCell::default().into_ref(&vm.ctx).into()); + } + for (i, cell) in closure.iter().enumerate() { + fastlocals[nlocals + num_cells + i] = Some(cell.clone().into()); } - Self { - localsplus: FrameUnsafeCell::new(localsplus), + let iframe = InterpreterFrame { + localsplus, locals: match scope.locals { Some(locals) => FrameLocals::with_locals(locals), None if code.flags.contains(bytecode::CodeFlags::NEWLOCALS) => FrameLocals::lazy(), @@ -696,8 +722,7 @@ impl Frame { code, func_obj, lasti: Radium::new(0), - cells_frees: FrameUnsafeCell::new(cells_frees), - prev_line: FrameUnsafeCell::new(0), + prev_line: 0, trace: PyMutex::new(vm.ctx.none()), trace_lines: PyMutex::new(true), trace_opcodes: PyMutex::new(false), @@ -708,6 +733,9 @@ impl Frame { locals_dirty: atomic::AtomicBool::new(false), pending_stack_pops: Default::default(), pending_unwind_from_stack: Default::default(), + }; + Self { + iframe: FrameUnsafeCell::new(iframe), } } @@ -718,7 +746,7 @@ impl Frame { /// or called from the same thread during trace callback). #[inline(always)] pub unsafe fn fastlocals(&self) -> &[Option] { - unsafe { (*self.localsplus.get()).fastlocals() } + unsafe { (*self.iframe.get()).localsplus.fastlocals() } } /// Access fastlocals mutably. @@ -728,7 +756,7 @@ impl Frame { #[inline(always)] #[allow(clippy::mut_from_ref)] pub unsafe fn fastlocals_mut(&self) -> &mut [Option] { - unsafe { (*self.localsplus.get()).fastlocals_mut() } + unsafe { (*self.iframe.get()).localsplus.fastlocals_mut() } } /// Migrate data-stack-backed storage to the heap, preserving all values, @@ -739,16 +767,16 @@ impl Frame { /// Caller must ensure the frame is not executing and the returned /// pointer is passed to `VirtualMachine::datastack_pop()`. 
pub(crate) unsafe fn materialize_localsplus(&self) -> Option<*mut u8> { - unsafe { (*self.localsplus.get()).materialize_to_heap() } + unsafe { (*self.iframe.get()).localsplus.materialize_to_heap() } } /// Clear evaluation stack and state-owned cell/free references. /// For full local/cell cleanup, call `clear_locals_and_stack()`. pub(crate) fn clear_stack_and_cells(&self) { // SAFETY: Called when frame is not executing (generator closed). + // Cell refs in fastlocals[nlocals..] are cleared by clear_locals_and_stack(). unsafe { - (*self.localsplus.get()).stack_clear(); - let _old = core::mem::take(&mut *self.cells_frees.get()); + (*self.iframe.get()).localsplus.stack_clear(); } } @@ -757,7 +785,7 @@ impl Frame { pub(crate) fn clear_locals_and_stack(&self) { self.clear_stack_and_cells(); // SAFETY: Frame is not executing (generator closed). - let fastlocals = unsafe { (*self.localsplus.get()).fastlocals_mut() }; + let fastlocals = unsafe { (*self.iframe.get()).localsplus.fastlocals_mut() }; for slot in fastlocals.iter_mut() { *slot = None; } @@ -767,7 +795,7 @@ impl Frame { pub(crate) fn get_cell_contents(&self, cell_idx: usize) -> Option { let nlocals = self.code.varnames.len(); // SAFETY: Frame not executing; no concurrent mutation. - let fastlocals = unsafe { (*self.localsplus.get()).fastlocals() }; + let fastlocals = unsafe { (*self.iframe.get()).localsplus.fastlocals() }; fastlocals .get(nlocals + cell_idx) .and_then(|slot| slot.as_ref()) @@ -777,8 +805,14 @@ impl Frame { /// Set cell contents by cell index. Only safe to call before frame execution starts. pub(crate) fn set_cell_contents(&self, cell_idx: usize, value: Option) { + let nlocals = self.code.varnames.len(); // SAFETY: Called before frame execution starts. 
- unsafe { (*self.cells_frees.get())[cell_idx].set(value) }; + let fastlocals = unsafe { (*self.iframe.get()).localsplus.fastlocals() }; + fastlocals[nlocals + cell_idx] + .as_ref() + .and_then(|obj| obj.downcast_ref::()) + .expect("cell slot empty or not a PyCell") + .set(value); } /// Store a borrowed back-reference to the owning generator/coroutine. @@ -837,7 +871,7 @@ impl Frame { } let code = &**self.code; // SAFETY: Called before generator resume; no concurrent access. - let fastlocals = unsafe { (*self.localsplus.get()).fastlocals_mut() }; + let fastlocals = unsafe { (*self.iframe.get()).localsplus.fastlocals_mut() }; let locals_map = self.locals.mapping(vm); for (i, &varname) in code.varnames.iter().enumerate() { if i >= fastlocals.len() { @@ -862,7 +896,7 @@ impl Frame { let j = core::cmp::min(map.len(), code.varnames.len()); let locals_map = locals.mapping(vm); if !code.varnames.is_empty() { - let fastlocals = unsafe { (*self.localsplus.get()).fastlocals() }; + let fastlocals = unsafe { (*self.iframe.get()).localsplus.fastlocals() }; for (&k, v) in zip(&map[..j], fastlocals) { match locals_map.ass_subscript(k, v.clone(), vm) { Ok(()) => {} @@ -901,24 +935,25 @@ impl Py { // SAFETY: Frame execution is single-threaded. Only one thread at a time // executes a given frame (enforced by the owner field and generator // running flag). Same safety argument as FastLocals (UnsafeCell). 
+ let iframe = unsafe { &mut *self.iframe.get() }; let exec = ExecutingFrame { - code: &self.code, - localsplus: unsafe { &mut *self.localsplus.get() }, - locals: &self.locals, - globals: &self.globals, - builtins: &self.builtins, - builtins_dict: if self.globals.class().is(vm.ctx.types.dict_type) { - self.builtins + code: &iframe.code, + localsplus: &mut iframe.localsplus, + locals: &iframe.locals, + globals: &iframe.globals, + builtins: &iframe.builtins, + builtins_dict: if iframe.globals.class().is(vm.ctx.types.dict_type) { + iframe + .builtins .downcast_ref_if_exact::(vm) // SAFETY: downcast_ref_if_exact already verified exact type .map(|d| unsafe { PyExact::ref_unchecked(d) }) } else { None }, - lasti: &self.lasti, + lasti: &iframe.lasti, object: self, - cells_frees: unsafe { &mut *self.cells_frees.get() }, - prev_line: unsafe { &mut *self.prev_line.get() }, + prev_line: &mut iframe.prev_line, monitoring_mask: 0, }; f(exec) @@ -960,17 +995,17 @@ impl Py { return None; } // SAFETY: Frame is not executing, so UnsafeCell access is safe. + let iframe = unsafe { &mut *self.iframe.get() }; let exec = ExecutingFrame { - code: &self.code, - localsplus: unsafe { &mut *self.localsplus.get() }, - locals: &self.locals, - globals: &self.globals, - builtins: &self.builtins, + code: &iframe.code, + localsplus: &mut iframe.localsplus, + locals: &iframe.locals, + globals: &iframe.globals, + builtins: &iframe.builtins, builtins_dict: None, - lasti: &self.lasti, + lasti: &iframe.lasti, object: self, - cells_frees: unsafe { &mut *self.cells_frees.get() }, - prev_line: unsafe { &mut *self.prev_line.get() }, + prev_line: &mut iframe.prev_line, monitoring_mask: 0, }; exec.yield_from_target().map(PyObject::to_owned) @@ -1010,12 +1045,213 @@ struct ExecutingFrame<'a> { builtins_dict: Option<&'a PyExact>, object: &'a Py, lasti: &'a PyAtomic, - cells_frees: &'a mut Box<[PyCellRef]>, prev_line: &'a mut u32, /// Cached monitoring events mask. 
Reloaded at Resume instruction only, monitoring_mask: u32, } +#[inline] +fn specialization_compact_int_value(i: &PyInt, vm: &VirtualMachine) -> Option { + // _PyLong_IsCompact(): a one-digit PyLong (base 2^30), + // i.e. abs(value) <= 2^30 - 1. + const CPYTHON_COMPACT_LONG_ABS_MAX: i64 = (1i64 << 30) - 1; + let v = i.try_to_primitive::(vm).ok()?; + if (-CPYTHON_COMPACT_LONG_ABS_MAX..=CPYTHON_COMPACT_LONG_ABS_MAX).contains(&v) { + Some(v as isize) + } else { + None + } +} + +#[inline] +fn compact_int_from_obj(obj: &PyObject, vm: &VirtualMachine) -> Option { + obj.downcast_ref_if_exact::(vm) + .and_then(|i| specialization_compact_int_value(i, vm)) +} + +#[inline] +fn exact_float_from_obj(obj: &PyObject, vm: &VirtualMachine) -> Option { + obj.downcast_ref_if_exact::(vm).map(|f| f.to_f64()) +} + +#[inline] +fn specialization_nonnegative_compact_index(i: &PyInt, vm: &VirtualMachine) -> Option { + // _PyLong_IsNonNegativeCompact(): a single base-2^30 digit. + const CPYTHON_COMPACT_LONG_MAX: u64 = (1u64 << 30) - 1; + let v = i.try_to_primitive::(vm).ok()?; + if v <= CPYTHON_COMPACT_LONG_MAX { + Some(v as usize) + } else { + None + } +} + +fn release_datastack_frame(frame: &Py, vm: &VirtualMachine) { + unsafe { + if let Some(base) = frame.materialize_localsplus() { + vm.datastack_pop(base); + } + } +} + +type BinaryOpExtendGuard = fn(&PyObject, &PyObject, &VirtualMachine) -> bool; +type BinaryOpExtendAction = fn(&PyObject, &PyObject, &VirtualMachine) -> Option; + +struct BinaryOpExtendSpecializationDescr { + oparg: bytecode::BinaryOperator, + guard: BinaryOpExtendGuard, + action: BinaryOpExtendAction, +} + +const BINARY_OP_EXTEND_EXTERNAL_CACHE_OFFSET: usize = 1; + +#[inline] +fn compactlongs_guard(lhs: &PyObject, rhs: &PyObject, vm: &VirtualMachine) -> bool { + compact_int_from_obj(lhs, vm).is_some() && compact_int_from_obj(rhs, vm).is_some() +} + +macro_rules! 
bitwise_longs_action { + ($name:ident, $op:tt) => { + #[inline] + fn $name(lhs: &PyObject, rhs: &PyObject, vm: &VirtualMachine) -> Option { + let lhs_val = compact_int_from_obj(lhs, vm)?; + let rhs_val = compact_int_from_obj(rhs, vm)?; + Some(vm.ctx.new_int(lhs_val $op rhs_val).into()) + } + }; +} +bitwise_longs_action!(compactlongs_or, |); +bitwise_longs_action!(compactlongs_and, &); +bitwise_longs_action!(compactlongs_xor, ^); + +#[inline] +fn float_compactlong_guard(lhs: &PyObject, rhs: &PyObject, vm: &VirtualMachine) -> bool { + exact_float_from_obj(lhs, vm).is_some_and(|f| !f.is_nan()) + && compact_int_from_obj(rhs, vm).is_some() +} + +#[inline] +fn nonzero_float_compactlong_guard(lhs: &PyObject, rhs: &PyObject, vm: &VirtualMachine) -> bool { + float_compactlong_guard(lhs, rhs, vm) && compact_int_from_obj(rhs, vm).is_some_and(|v| v != 0) +} + +macro_rules! float_long_action { + ($name:ident, $op:tt) => { + #[inline] + fn $name(lhs: &PyObject, rhs: &PyObject, vm: &VirtualMachine) -> Option { + let lhs_val = exact_float_from_obj(lhs, vm)?; + let rhs_val = compact_int_from_obj(rhs, vm)?; + Some(vm.ctx.new_float(lhs_val $op rhs_val as f64).into()) + } + }; +} +float_long_action!(float_compactlong_add, +); +float_long_action!(float_compactlong_subtract, -); +float_long_action!(float_compactlong_multiply, *); +float_long_action!(float_compactlong_true_div, /); + +#[inline] +fn compactlong_float_guard(lhs: &PyObject, rhs: &PyObject, vm: &VirtualMachine) -> bool { + compact_int_from_obj(lhs, vm).is_some() + && exact_float_from_obj(rhs, vm).is_some_and(|f| !f.is_nan()) +} + +#[inline] +fn nonzero_compactlong_float_guard(lhs: &PyObject, rhs: &PyObject, vm: &VirtualMachine) -> bool { + compactlong_float_guard(lhs, rhs, vm) && exact_float_from_obj(rhs, vm).is_some_and(|f| f != 0.0) +} + +macro_rules! 
long_float_action { + ($name:ident, $op:tt) => { + #[inline] + fn $name(lhs: &PyObject, rhs: &PyObject, vm: &VirtualMachine) -> Option { + let lhs_val = compact_int_from_obj(lhs, vm)?; + let rhs_val = exact_float_from_obj(rhs, vm)?; + Some(vm.ctx.new_float(lhs_val as f64 $op rhs_val).into()) + } + }; +} +long_float_action!(compactlong_float_add, +); +long_float_action!(compactlong_float_subtract, -); +long_float_action!(compactlong_float_multiply, *); +long_float_action!(compactlong_float_true_div, /); + +static BINARY_OP_EXTEND_DESCRIPTORS: &[BinaryOpExtendSpecializationDescr] = &[ + // long-long arithmetic + BinaryOpExtendSpecializationDescr { + oparg: bytecode::BinaryOperator::Or, + guard: compactlongs_guard, + action: compactlongs_or, + }, + BinaryOpExtendSpecializationDescr { + oparg: bytecode::BinaryOperator::And, + guard: compactlongs_guard, + action: compactlongs_and, + }, + BinaryOpExtendSpecializationDescr { + oparg: bytecode::BinaryOperator::Xor, + guard: compactlongs_guard, + action: compactlongs_xor, + }, + BinaryOpExtendSpecializationDescr { + oparg: bytecode::BinaryOperator::InplaceOr, + guard: compactlongs_guard, + action: compactlongs_or, + }, + BinaryOpExtendSpecializationDescr { + oparg: bytecode::BinaryOperator::InplaceAnd, + guard: compactlongs_guard, + action: compactlongs_and, + }, + BinaryOpExtendSpecializationDescr { + oparg: bytecode::BinaryOperator::InplaceXor, + guard: compactlongs_guard, + action: compactlongs_xor, + }, + // float-long arithmetic + BinaryOpExtendSpecializationDescr { + oparg: bytecode::BinaryOperator::Add, + guard: float_compactlong_guard, + action: float_compactlong_add, + }, + BinaryOpExtendSpecializationDescr { + oparg: bytecode::BinaryOperator::Subtract, + guard: float_compactlong_guard, + action: float_compactlong_subtract, + }, + BinaryOpExtendSpecializationDescr { + oparg: bytecode::BinaryOperator::TrueDivide, + guard: nonzero_float_compactlong_guard, + action: float_compactlong_true_div, + }, + 
BinaryOpExtendSpecializationDescr { + oparg: bytecode::BinaryOperator::Multiply, + guard: float_compactlong_guard, + action: float_compactlong_multiply, + }, + // long-float arithmetic + BinaryOpExtendSpecializationDescr { + oparg: bytecode::BinaryOperator::Add, + guard: compactlong_float_guard, + action: compactlong_float_add, + }, + BinaryOpExtendSpecializationDescr { + oparg: bytecode::BinaryOperator::Subtract, + guard: compactlong_float_guard, + action: compactlong_float_subtract, + }, + BinaryOpExtendSpecializationDescr { + oparg: bytecode::BinaryOperator::TrueDivide, + guard: nonzero_compactlong_float_guard, + action: compactlong_float_true_div, + }, + BinaryOpExtendSpecializationDescr { + oparg: bytecode::BinaryOperator::Multiply, + guard: compactlong_float_guard, + action: compactlong_float_multiply, + }, +]; + impl fmt::Debug for ExecutingFrame<'_> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("ExecutingFrame") @@ -1026,6 +1262,57 @@ impl fmt::Debug for ExecutingFrame<'_> { } impl ExecutingFrame<'_> { + #[inline] + fn monitoring_disabled_for_code(&self, vm: &VirtualMachine) -> bool { + self.code.is(&vm.ctx.init_cleanup_code) + } + + fn specialization_new_init_cleanup_frame(&self, vm: &VirtualMachine) -> FrameRef { + Frame::new( + vm.ctx.init_cleanup_code.clone(), + Scope::new( + Some(ArgMapping::from_dict_exact(vm.ctx.new_dict())), + self.globals.clone(), + ), + self.builtins.clone(), + &[], + None, + true, + vm, + ) + .into_ref(&vm.ctx) + } + + fn specialization_run_init_cleanup_shim( + &self, + new_obj: PyObjectRef, + init_func: &Py, + pos_args: Vec, + vm: &VirtualMachine, + ) -> PyResult { + let shim = self.specialization_new_init_cleanup_frame(vm); + let shim_result = vm.with_frame_untraced(shim.clone(), |shim| { + shim.with_exec(vm, |mut exec| exec.push_value(new_obj.clone())); + + let mut all_args = Vec::with_capacity(pos_args.len() + 1); + all_args.push(new_obj.clone()); + all_args.extend(pos_args); + + let init_frame = 
init_func.prepare_exact_args_frame(all_args, vm); + let init_result = vm.run_frame(init_frame.clone()); + release_datastack_frame(&init_frame, vm); + let init_result = init_result?; + + shim.with_exec(vm, |mut exec| exec.push_value(init_result)); + match shim.run(vm)? { + ExecutionResult::Return(value) => Ok(value), + ExecutionResult::Yield(_) => unreachable!("_Py_InitCleanup shim cannot yield"), + } + }); + release_datastack_frame(&shim, vm); + shim_result + } + #[inline(always)] fn update_lasti(&mut self, f: impl FnOnce(&mut u32)) { let mut val = self.lasti.load(Relaxed); @@ -1038,6 +1325,18 @@ impl ExecutingFrame<'_> { self.lasti.load(Relaxed) } + /// Access the PyCellRef at the given cell/free variable index. + /// `cell_idx` is 0-based: 0..ncells for cellvars, ncells.. for freevars. + #[inline(always)] + fn cell_ref(&self, cell_idx: usize) -> &PyCell { + let nlocals = self.code.varnames.len(); + self.localsplus.fastlocals()[nlocals + cell_idx] + .as_ref() + .expect("cell slot empty") + .downcast_ref::() + .expect("cell slot is not a PyCell") + } + /// Perform deferred stack unwinding after set_f_lineno. /// /// set_f_lineno cannot pop the value stack directly because the execution @@ -1158,7 +1457,9 @@ impl ExecutingFrame<'_> { } } - if let Err(exception) = vm.check_signals() { + if vm.eval_breaker_tripped() + && let Err(exception) = vm.check_signals() + { #[cold] fn handle_signal_exception( frame: &mut ExecutingFrame<'_>, @@ -1626,22 +1927,44 @@ impl ExecutingFrame<'_> { self.adaptive(|s, ii, cb| s.specialize_binary_op(vm, op_val, ii, cb)); self.execute_bin_op(vm, op_val) } - // TODO: In CPython, this does in-place unicode concatenation when - // refcount is 1. Falls back to regular iadd for now. + // Super-instruction for BINARY_OP_ADD_UNICODE + STORE_FAST targeting + // the left local, matching BINARY_OP_INPLACE_ADD_UNICODE shape. 
Instruction::BinaryOpInplaceAddUnicode => { let b = self.top_value(); let a = self.nth_value(1); - if let (Some(a_str), Some(b_str)) = ( + let instr_idx = self.lasti() as usize - 1; + let cache_base = instr_idx + 1; + let target_local = self.binary_op_inplace_unicode_target_local(cache_base, a); + if let (Some(_a_str), Some(_b_str), Some(target_local)) = ( a.downcast_ref_if_exact::(vm), b.downcast_ref_if_exact::(vm), + target_local, ) { - let result = a_str.as_wtf8().py_add(b_str.as_wtf8()); - self.pop_value(); - self.pop_value(); - self.push_value(result.to_pyobject(vm)); + let right = self.pop_value(); + let left = self.pop_value(); + + let local_obj = self.localsplus.fastlocals_mut()[target_local] + .take() + .expect("BINARY_OP_INPLACE_ADD_UNICODE target local missing"); + debug_assert!(local_obj.is(&left)); + let mut local_str = local_obj + .downcast_exact::(vm) + .expect("BINARY_OP_INPLACE_ADD_UNICODE target local not exact str") + .into_pyref(); + drop(left); + let right_str = right + .downcast_ref_if_exact::(vm) + .expect("BINARY_OP_INPLACE_ADD_UNICODE right operand not exact str"); + local_str.concat_in_place(right_str.as_wtf8(), vm); + + self.localsplus.fastlocals_mut()[target_local] = Some(local_str.into()); + self.jump_relative_forward( + 1, + Instruction::BinaryOpInplaceAddUnicode.cache_entries() as u32, + ); Ok(None) } else { - self.execute_bin_op(vm, bytecode::BinaryOperator::InplaceAdd) + self.execute_bin_op(vm, self.binary_op_from_arg(arg)) } } Instruction::BinarySlice => { @@ -1836,12 +2159,12 @@ impl ExecutingFrame<'_> { } Instruction::DeleteAttr { namei: idx } => self.delete_attr(vm, idx.get(arg)), Instruction::DeleteDeref { i } => { - self.cells_frees[i.get(arg) as usize].set(None); + self.cell_ref(i.get(arg) as usize).set(None); Ok(None) } - Instruction::DeleteFast { var_num: idx } => { + Instruction::DeleteFast { var_num } => { let fastlocals = self.localsplus.fastlocals_mut(); - let idx = idx.get(arg) as usize; + let idx = var_num.get(arg); 
if fastlocals[idx].is_none() { return Err(vm.new_exception_msg( vm.ctx.exceptions.unbound_local_error.to_owned(), @@ -1989,7 +2312,7 @@ impl ExecutingFrame<'_> { Instruction::ForIter { .. } => { // Relative forward jump: target = lasti + caches + delta let target = bytecode::Label(self.lasti() + 1 + u32::from(arg)); - self.adaptive(|s, ii, cb| s.specialize_for_iter(vm, ii, cb)); + self.adaptive(|s, ii, cb| s.specialize_for_iter(vm, u32::from(arg), ii, cb)); self.execute_for_iter(vm, target)?; Ok(None) } @@ -2094,9 +2417,7 @@ impl ExecutingFrame<'_> { if let Some(coro) = iter.downcast_ref::() && coro.as_coro().frame().yield_from_target().is_some() { - return Err( - vm.new_runtime_error("coroutine is being awaited already".to_owned()) - ); + return Err(vm.new_runtime_error("coroutine is being awaited already")); } self.push_value(iter); @@ -2120,8 +2441,7 @@ impl ExecutingFrame<'_> { bytecode::CodeFlags::COROUTINE | bytecode::CodeFlags::ITERABLE_COROUTINE, ) { return Err(vm.new_type_error( - "cannot 'yield from' a coroutine object in a non-coroutine generator" - .to_owned(), + "cannot 'yield from' a coroutine object in a non-coroutine generator", )); } iterable @@ -2235,7 +2555,7 @@ impl ExecutingFrame<'_> { .get_item_opt(identifier!(vm, __build_class__), vm)? .ok_or_else(|| { vm.new_name_error( - "__build_class__ not found".to_owned(), + "__build_class__ not found", identifier!(vm, __build_class__).to_owned(), ) })? 
@@ -2245,7 +2565,7 @@ impl ExecutingFrame<'_> { .map_err(|e| { if e.fast_isinstance(vm.ctx.exceptions.key_error) { vm.new_name_error( - "__build_class__ not found".to_owned(), + "__build_class__ not found", identifier!(vm, __build_class__).to_owned(), ) } else { @@ -2283,7 +2603,8 @@ impl ExecutingFrame<'_> { }; self.push_value(match value { Some(v) => v, - None => self.cells_frees[i] + None => self + .cell_ref(i) .get() .ok_or_else(|| self.unbound_cell_exception(i, vm))?, }); @@ -2312,8 +2633,8 @@ impl ExecutingFrame<'_> { }); Ok(None) } - Instruction::LoadConst { consti: idx } => { - self.push_value(self.code.constants[idx.get(arg) as usize].clone().into()); + Instruction::LoadConst { consti } => { + self.push_value(self.code.constants[consti.get(arg)].clone().into()); // Mirror CPython's LOAD_CONST family transition. RustPython does // not currently distinguish immortal constants at runtime. let instr_idx = self.lasti() as usize - 1; @@ -2325,7 +2646,7 @@ impl ExecutingFrame<'_> { Ok(None) } Instruction::LoadConstMortal | Instruction::LoadConstImmortal => { - self.push_value(self.code.constants[u32::from(arg) as usize].clone().into()); + self.push_value(self.code.constants[u32::from(arg).into()].clone().into()); Ok(None) } Instruction::LoadCommonConstant { idx } => { @@ -2352,13 +2673,14 @@ impl ExecutingFrame<'_> { } Instruction::LoadDeref { i } => { let idx = i.get(arg) as usize; - let x = self.cells_frees[idx] + let x = self + .cell_ref(idx) .get() .ok_or_else(|| self.unbound_cell_exception(idx, vm))?; self.push_value(x); Ok(None) } - Instruction::LoadFast { var_num: idx } => { + Instruction::LoadFast { var_num } => { #[cold] fn reference_error( varname: &'static PyStrInterned, @@ -2369,27 +2691,27 @@ impl ExecutingFrame<'_> { format!("local variable '{varname}' referenced before assignment").into(), ) } - let idx = idx.get(arg) as usize; + let idx = var_num.get(arg); let x = self.localsplus.fastlocals()[idx] .clone() .ok_or_else(|| 
reference_error(self.code.varnames[idx], vm))?; self.push_value(x); Ok(None) } - Instruction::LoadFastAndClear { var_num: idx } => { + Instruction::LoadFastAndClear { var_num } => { // Load value and clear the slot (for inlined comprehensions) // If slot is empty, push None (not an error - variable may not exist yet) - let idx = idx.get(arg) as usize; + let idx = var_num.get(arg); let x = self.localsplus.fastlocals_mut()[idx] .take() .unwrap_or_else(|| vm.ctx.none()); self.push_value(x); Ok(None) } - Instruction::LoadFastCheck { var_num: idx } => { + Instruction::LoadFastCheck { var_num } => { // Same as LoadFast but explicitly checks for unbound locals // (LoadFast in RustPython already does this check) - let idx = idx.get(arg) as usize; + let idx = var_num.get(arg); let x = self.localsplus.fastlocals()[idx].clone().ok_or_else(|| { vm.new_exception_msg( vm.ctx.exceptions.unbound_local_error.to_owned(), @@ -2437,8 +2759,8 @@ impl ExecutingFrame<'_> { // Borrow optimization not yet active; falls back to clone. // push_borrowed() is available but disabled until stack // lifetime issues at yield/exception points are resolved. 
- Instruction::LoadFastBorrow { var_num: idx } => { - let idx = idx.get(arg) as usize; + Instruction::LoadFastBorrow { var_num } => { + let idx = var_num.get(arg); let x = self.localsplus.fastlocals()[idx].clone().ok_or_else(|| { vm.new_exception_msg( vm.ctx.exceptions.unbound_local_error.to_owned(), @@ -2604,7 +2926,7 @@ impl ExecutingFrame<'_> { Some(s) => s, None => { return Err(vm.new_type_error( - "__match_args__ elements must be strings".to_string(), + "__match_args__ elements must be strings", )); } }; @@ -2635,16 +2957,14 @@ impl ExecutingFrame<'_> { } else if nargs_val > 1 { // Too many positional arguments for MATCH_SELF return Err(vm.new_type_error( - "class pattern accepts at most 1 positional sub-pattern for MATCH_SELF types" - .to_string(), + "class pattern accepts at most 1 positional sub-pattern for MATCH_SELF types", )); } } else { // No __match_args__ and not a MATCH_SELF type if nargs_val > 0 { return Err(vm.new_type_error( - "class pattern defines no positional sub-patterns (__match_args__ missing)" - .to_string(), + "class pattern defines no positional sub-patterns (__match_args__ missing)", )); } } @@ -2834,6 +3154,17 @@ impl ExecutingFrame<'_> { self.code.instructions.quicken(); atomic::fence(atomic::Ordering::Release); } + if self.monitoring_disabled_for_code(vm) { + let global_ver = vm + .state + .instrumentation_version + .load(atomic::Ordering::Acquire); + monitoring::instrument_code(self.code, 0); + self.code + .instrumentation_version + .store(global_ver, atomic::Ordering::Release); + return Ok(None); + } // Check if bytecode needs re-instrumentation let global_ver = vm .state @@ -2965,13 +3296,13 @@ impl ExecutingFrame<'_> { } Instruction::StoreDeref { i } => { let value = self.pop_value(); - self.cells_frees[i.get(arg) as usize].set(Some(value)); + self.cell_ref(i.get(arg) as usize).set(Some(value)); Ok(None) } - Instruction::StoreFast { var_num: idx } => { + Instruction::StoreFast { var_num } => { let value = self.pop_value(); let 
fastlocals = self.localsplus.fastlocals_mut(); - fastlocals[idx.get(arg) as usize] = Some(value); + fastlocals[var_num.get(arg)] = Some(value); Ok(None) } Instruction::StoreFastLoadFast { var_nums } => { @@ -3059,8 +3390,9 @@ impl ExecutingFrame<'_> { self.execute_unpack_ex(vm, args.before, args.after) } Instruction::UnpackSequence { count: size } => { - self.adaptive(|s, ii, cb| s.specialize_unpack_sequence(vm, ii, cb)); - self.unpack_sequence(size.get(arg), vm) + let expected = size.get(arg); + self.adaptive(|s, ii, cb| s.specialize_unpack_sequence(vm, expected, ii, cb)); + self.unpack_sequence(expected, vm) } Instruction::WithExceptStart => { // Stack: [..., __exit__, lasti, prev_exc, exc] @@ -3107,12 +3439,28 @@ impl ExecutingFrame<'_> { } Instruction::Send { .. } => { // (receiver, v -- receiver, retval) - self.adaptive(|s, ii, cb| s.specialize_send(ii, cb)); + self.adaptive(|s, ii, cb| s.specialize_send(vm, ii, cb)); let exit_label = bytecode::Label(self.lasti() + 1 + u32::from(arg)); + let receiver = self.nth_value(1); + let can_fast_send = !self.specialization_eval_frame_active(vm) + && (receiver.downcast_ref_if_exact::(vm).is_some() + || receiver.downcast_ref_if_exact::(vm).is_some()) + && self + .builtin_coro(receiver) + .is_some_and(|coro| !coro.running() && !coro.closed()); let val = self.pop_value(); let receiver = self.top_value(); - - match self._send(receiver, val, vm)? { + let ret = if can_fast_send { + let coro = self.builtin_coro(receiver).unwrap(); + if vm.is_none(&val) { + coro.send_none(receiver, vm)? + } else { + coro.send(receiver, val, vm)? + } + } else { + self._send(receiver, val, vm)? 
+ }; + match ret { PyIterReturn::Return(value) => { self.push_value(value); Ok(None) @@ -3133,13 +3481,23 @@ impl ExecutingFrame<'_> { let exit_label = bytecode::Label(self.lasti() + 1 + u32::from(arg)); // Stack: [receiver, val] — peek receiver before popping let receiver = self.nth_value(1); - let is_coro = self.builtin_coro(receiver).is_some(); + let can_fast_send = !self.specialization_eval_frame_active(vm) + && (receiver.downcast_ref_if_exact::(vm).is_some() + || receiver.downcast_ref_if_exact::(vm).is_some()) + && self + .builtin_coro(receiver) + .is_some_and(|coro| !coro.running() && !coro.closed()); let val = self.pop_value(); - let receiver = self.top_value(); - if is_coro { + if can_fast_send { + let receiver = self.top_value(); let coro = self.builtin_coro(receiver).unwrap(); - match coro.send(receiver, val, vm)? { + let ret = if vm.is_none(&val) { + coro.send_none(receiver, vm)? + } else { + coro.send(receiver, val, vm)? + }; + match ret { PyIterReturn::Return(value) => { self.push_value(value); return Ok(None); @@ -3156,6 +3514,7 @@ impl ExecutingFrame<'_> { } } } + let receiver = self.top_value(); match self._send(receiver, val, vm)? 
{ PyIterReturn::Return(value) => { self.push_value(value); @@ -3227,7 +3586,7 @@ impl ExecutingFrame<'_> { let exc = exc .downcast::() - .map_err(|_| vm.new_type_error("exception expected".to_owned()))?; + .map_err(|_| vm.new_type_error("exception expected"))?; Err(exc) } Instruction::UnaryInvert => { @@ -3251,8 +3610,7 @@ impl ExecutingFrame<'_> { // Specialized LOAD_ATTR opcodes Instruction::LoadAttrMethodNoDict => { let oparg = LoadAttr::new(u32::from(arg)); - let instr_idx = self.lasti() as usize - 1; - let cache_base = instr_idx + 1; + let cache_base = self.lasti() as usize; let owner = self.top_value(); let type_version = self.code.instructions.read_cache_u32(cache_base + 1); @@ -3266,20 +3624,12 @@ impl ExecutingFrame<'_> { self.push_value(owner); Ok(None) } else { - self.deoptimize_at( - Instruction::LoadAttr { - namei: Arg::marker(), - }, - instr_idx, - cache_base, - ); self.load_attr_slow(vm, oparg) } } Instruction::LoadAttrMethodLazyDict => { let oparg = LoadAttr::new(u32::from(arg)); - let instr_idx = self.lasti() as usize - 1; - let cache_base = instr_idx + 1; + let cache_base = self.lasti() as usize; let owner = self.top_value(); let type_version = self.code.instructions.read_cache_u32(cache_base + 1); @@ -3294,20 +3644,12 @@ impl ExecutingFrame<'_> { self.push_value(owner); Ok(None) } else { - self.deoptimize_at( - Instruction::LoadAttr { - namei: Arg::marker(), - }, - instr_idx, - cache_base, - ); self.load_attr_slow(vm, oparg) } } Instruction::LoadAttrMethodWithValues => { let oparg = LoadAttr::new(u32::from(arg)); - let instr_idx = self.lasti() as usize - 1; - let cache_base = instr_idx + 1; + let cache_base = self.lasti() as usize; let attr_name = self.code.names[oparg.name_idx() as usize]; let owner = self.top_value(); @@ -3320,23 +3662,7 @@ impl ExecutingFrame<'_> { Ok(Some(_)) => true, Ok(None) => false, Err(_) => { - // Dict lookup error → deoptimize to safe path - unsafe { - self.code.instructions.replace_op( - instr_idx, - 
Instruction::LoadAttr { - namei: Arg::marker(), - }, - ); - self.code.instructions.write_adaptive_counter( - cache_base, - bytecode::adaptive_counter_backoff( - self.code - .instructions - .read_adaptive_counter(cache_base), - ), - ); - } + // Dict lookup error -> use safe path. return self.load_attr_slow(vm, oparg); } } @@ -3354,19 +3680,11 @@ impl ExecutingFrame<'_> { return Ok(None); } } - self.deoptimize_at( - Instruction::LoadAttr { - namei: Arg::marker(), - }, - instr_idx, - cache_base, - ); self.load_attr_slow(vm, oparg) } Instruction::LoadAttrInstanceValue => { let oparg = LoadAttr::new(u32::from(arg)); - let instr_idx = self.lasti() as usize - 1; - let cache_base = instr_idx + 1; + let cache_base = self.lasti() as usize; let attr_name = self.code.names[oparg.name_idx() as usize]; let owner = self.top_value(); @@ -3384,19 +3702,11 @@ impl ExecutingFrame<'_> { } // Not in instance dict — fall through to class lookup via slow path } - self.deoptimize_at( - Instruction::LoadAttr { - namei: Arg::marker(), - }, - instr_idx, - cache_base, - ); self.load_attr_slow(vm, oparg) } Instruction::LoadAttrWithHint => { let oparg = LoadAttr::new(u32::from(arg)); - let instr_idx = self.lasti() as usize - 1; - let cache_base = instr_idx + 1; + let cache_base = self.lasti() as usize; let attr_name = self.code.names[oparg.name_idx() as usize]; let owner = self.top_value(); @@ -3417,19 +3727,11 @@ impl ExecutingFrame<'_> { return Ok(None); } - self.deoptimize_at( - Instruction::LoadAttr { - namei: Arg::marker(), - }, - instr_idx, - cache_base, - ); self.load_attr_slow(vm, oparg) } Instruction::LoadAttrModule => { let oparg = LoadAttr::new(u32::from(arg)); - let instr_idx = self.lasti() as usize - 1; - let cache_base = instr_idx + 1; + let cache_base = self.lasti() as usize; let attr_name = self.code.names[oparg.name_idx() as usize]; let owner = self.top_value(); @@ -3449,27 +3751,11 @@ impl ExecutingFrame<'_> { } return Ok(None); } - // Deoptimize - unsafe { - 
self.code.instructions.replace_op( - instr_idx, - Instruction::LoadAttr { - namei: Arg::marker(), - }, - ); - self.code.instructions.write_adaptive_counter( - cache_base, - bytecode::adaptive_counter_backoff( - self.code.instructions.read_adaptive_counter(cache_base), - ), - ); - } self.load_attr_slow(vm, oparg) } Instruction::LoadAttrNondescriptorNoDict => { let oparg = LoadAttr::new(u32::from(arg)); - let instr_idx = self.lasti() as usize - 1; - let cache_base = instr_idx + 1; + let cache_base = self.lasti() as usize; let owner = self.top_value(); let type_version = self.code.instructions.read_cache_u32(cache_base + 1); @@ -3487,26 +3773,11 @@ impl ExecutingFrame<'_> { } return Ok(None); } - unsafe { - self.code.instructions.replace_op( - instr_idx, - Instruction::LoadAttr { - namei: Arg::marker(), - }, - ); - self.code.instructions.write_adaptive_counter( - cache_base, - bytecode::adaptive_counter_backoff( - self.code.instructions.read_adaptive_counter(cache_base), - ), - ); - } self.load_attr_slow(vm, oparg) } Instruction::LoadAttrNondescriptorWithValues => { let oparg = LoadAttr::new(u32::from(arg)); - let instr_idx = self.lasti() as usize - 1; - let cache_base = instr_idx + 1; + let cache_base = self.lasti() as usize; let attr_name = self.code.names[oparg.name_idx() as usize]; let owner = self.top_value(); @@ -3529,13 +3800,6 @@ impl ExecutingFrame<'_> { // Not in instance dict — use cached class attr let Some(attr) = self.try_read_cached_descriptor(cache_base, type_version) else { - self.deoptimize_at( - Instruction::LoadAttr { - namei: Arg::marker(), - }, - instr_idx, - cache_base, - ); return self.load_attr_slow(vm, oparg); }; self.pop_value(); @@ -3547,26 +3811,11 @@ impl ExecutingFrame<'_> { } return Ok(None); } - unsafe { - self.code.instructions.replace_op( - instr_idx, - Instruction::LoadAttr { - namei: Arg::marker(), - }, - ); - self.code.instructions.write_adaptive_counter( - cache_base, - bytecode::adaptive_counter_backoff( - 
self.code.instructions.read_adaptive_counter(cache_base), - ), - ); - } self.load_attr_slow(vm, oparg) } Instruction::LoadAttrClass => { let oparg = LoadAttr::new(u32::from(arg)); - let instr_idx = self.lasti() as usize - 1; - let cache_base = instr_idx + 1; + let cache_base = self.lasti() as usize; let owner = self.top_value(); let type_version = self.code.instructions.read_cache_u32(cache_base + 1); @@ -3585,26 +3834,11 @@ impl ExecutingFrame<'_> { } return Ok(None); } - unsafe { - self.code.instructions.replace_op( - instr_idx, - Instruction::LoadAttr { - namei: Arg::marker(), - }, - ); - self.code.instructions.write_adaptive_counter( - cache_base, - bytecode::adaptive_counter_backoff( - self.code.instructions.read_adaptive_counter(cache_base), - ), - ); - } self.load_attr_slow(vm, oparg) } Instruction::LoadAttrClassWithMetaclassCheck => { let oparg = LoadAttr::new(u32::from(arg)); - let instr_idx = self.lasti() as usize - 1; - let cache_base = instr_idx + 1; + let cache_base = self.lasti() as usize; let owner = self.top_value(); let type_version = self.code.instructions.read_cache_u32(cache_base + 1); @@ -3626,26 +3860,38 @@ impl ExecutingFrame<'_> { } return Ok(None); } - self.deoptimize_at( - Instruction::LoadAttr { - namei: Arg::marker(), - }, - instr_idx, - cache_base, - ); self.load_attr_slow(vm, oparg) } Instruction::LoadAttrGetattributeOverridden => { let oparg = LoadAttr::new(u32::from(arg)); - self.deoptimize(Instruction::LoadAttr { - namei: Arg::marker(), - }); + let cache_base = self.lasti() as usize; + let owner = self.top_value(); + let type_version = self.code.instructions.read_cache_u32(cache_base + 1); + let func_version = self.code.instructions.read_cache_u32(cache_base + 3); + + if !oparg.is_method() + && !self.specialization_eval_frame_active(vm) + && type_version != 0 + && func_version != 0 + && owner.class().tp_version_tag.load(Acquire) == type_version + && let Some(func_obj) = + self.try_read_cached_descriptor(cache_base, type_version) + 
&& let Some(func) = func_obj.downcast_ref_if_exact::(vm) + && func.func_version() == func_version + && self.specialization_has_datastack_space_for_func(vm, func) + { + debug_assert!(func.has_exact_argcount(2)); + let owner = self.pop_value(); + let attr_name = self.code.names[oparg.name_idx() as usize].to_owned().into(); + let result = func.invoke_exact_args(vec![owner, attr_name], vm)?; + self.push_value(result); + return Ok(None); + } self.load_attr_slow(vm, oparg) } Instruction::LoadAttrSlot => { let oparg = LoadAttr::new(u32::from(arg)); - let instr_idx = self.lasti() as usize - 1; - let cache_base = instr_idx + 1; + let cache_base = self.lasti() as usize; let owner = self.top_value(); let type_version = self.code.instructions.read_cache_u32(cache_base + 1); @@ -3665,55 +3911,29 @@ impl ExecutingFrame<'_> { } // Slot is None → AttributeError (fall through to slow path) } - unsafe { - self.code.instructions.replace_op( - instr_idx, - Instruction::LoadAttr { - namei: Arg::marker(), - }, - ); - self.code.instructions.write_adaptive_counter( - cache_base, - bytecode::adaptive_counter_backoff( - self.code.instructions.read_adaptive_counter(cache_base), - ), - ); - } self.load_attr_slow(vm, oparg) } Instruction::LoadAttrProperty => { let oparg = LoadAttr::new(u32::from(arg)); - let instr_idx = self.lasti() as usize - 1; - let cache_base = instr_idx + 1; + let cache_base = self.lasti() as usize; let owner = self.top_value(); let type_version = self.code.instructions.read_cache_u32(cache_base + 1); if type_version != 0 + && !self.specialization_eval_frame_active(vm) && owner.class().tp_version_tag.load(Acquire) == type_version - && let Some(descr) = self.try_read_cached_descriptor(cache_base, type_version) - && let Some(prop) = descr.downcast_ref::() - && let Some(getter) = prop.get_fget() + && let Some(fget_obj) = + self.try_read_cached_descriptor(cache_base, type_version) + && let Some(func) = fget_obj.downcast_ref_if_exact::(vm) + && func.can_specialize_call(1) + && 
self.specialization_has_datastack_space_for_func(vm, func) { let owner = self.pop_value(); - let result = getter.call((owner,), vm)?; + let result = func.invoke_exact_args(vec![owner], vm)?; self.push_value(result); return Ok(None); } - unsafe { - self.code.instructions.replace_op( - instr_idx, - Instruction::LoadAttr { - namei: Arg::marker(), - }, - ); - self.code.instructions.write_adaptive_counter( - cache_base, - bytecode::adaptive_counter_backoff( - self.code.instructions.read_adaptive_counter(cache_base), - ), - ); - } self.load_attr_slow(vm, oparg) } Instruction::StoreAttrInstanceValue => { @@ -3781,15 +4001,13 @@ impl ExecutingFrame<'_> { let value = self.pop_value(); if let Some(list) = obj.downcast_ref_if_exact::(vm) && let Some(int_idx) = idx.downcast_ref_if_exact::(vm) - && let Ok(i) = int_idx.try_to_primitive::(vm) + && let Some(i) = specialization_nonnegative_compact_index(int_idx, vm) { let mut vec = list.borrow_vec_mut(); - if let Some(pos) = vec.wrap_index(i) { - vec[pos] = value; + if i < vec.len() { + vec[i] = value; return Ok(None); } - drop(vec); - return Err(vm.new_index_error("list assignment index out of range")); } obj.set_item(&*idx, value, vm)?; Ok(None) @@ -3842,10 +4060,40 @@ impl ExecutingFrame<'_> { self.execute_bin_op(vm, bytecode::BinaryOperator::Add) } } - Instruction::BinaryOpSubscrGetitem | Instruction::BinaryOpExtend => { - let op = bytecode::BinaryOperator::try_from(u32::from(arg)) - .unwrap_or(bytecode::BinaryOperator::Subscr); - self.execute_bin_op(vm, op) + Instruction::BinaryOpSubscrGetitem => { + let owner = self.nth_value(1); + if !self.specialization_eval_frame_active(vm) + && let Some((func, func_version)) = + owner.class().get_cached_getitem_for_specialization() + && func.func_version() == func_version + && self.specialization_has_datastack_space_for_func(vm, &func) + { + debug_assert!(func.has_exact_argcount(2)); + let sub = self.pop_value(); + let owner = self.pop_value(); + let result = 
func.invoke_exact_args(vec![owner, sub], vm)?; + self.push_value(result); + return Ok(None); + } + self.execute_bin_op(vm, bytecode::BinaryOperator::Subscr) + } + Instruction::BinaryOpExtend => { + let op = self.binary_op_from_arg(arg); + let b = self.top_value(); + let a = self.nth_value(1); + let cache_base = self.lasti() as usize; + if let Some(descr) = self.read_cached_binary_op_extend_descr(cache_base) + && descr.oparg == op + && (descr.guard)(a, b, vm) + && let Some(result) = (descr.action)(a, b, vm) + { + self.pop_value(); + self.pop_value(); + self.push_value(result); + Ok(None) + } else { + self.execute_bin_op(vm, op) + } } Instruction::BinaryOpSubscrListInt => { let b = self.top_value(); @@ -3853,19 +4101,17 @@ impl ExecutingFrame<'_> { if let (Some(list), Some(idx)) = ( a.downcast_ref_if_exact::(vm), b.downcast_ref_if_exact::(vm), - ) && let Ok(i) = idx.try_to_primitive::(vm) + ) && let Some(i) = specialization_nonnegative_compact_index(idx, vm) { let vec = list.borrow_vec(); - if let Some(pos) = vec.wrap_index(i) { - let value = vec.do_get(pos); + if i < vec.len() { + let value = vec.do_get(i); drop(vec); self.pop_value(); self.pop_value(); self.push_value(value); return Ok(None); } - drop(vec); - return Err(vm.new_index_error("list index out of range")); } self.execute_bin_op(vm, bytecode::BinaryOperator::Subscr) } @@ -3875,17 +4121,16 @@ impl ExecutingFrame<'_> { if let (Some(tuple), Some(idx)) = ( a.downcast_ref_if_exact::(vm), b.downcast_ref_if_exact::(vm), - ) && let Ok(i) = idx.try_to_primitive::(vm) + ) && let Some(i) = specialization_nonnegative_compact_index(idx, vm) { let elements = tuple.as_slice(); - if let Some(pos) = elements.wrap_index(i) { - let value = elements[pos].clone(); + if i < elements.len() { + let value = elements[i].clone(); self.pop_value(); self.pop_value(); self.push_value(value); return Ok(None); } - return Err(vm.new_index_error("tuple index out of range")); } self.execute_bin_op(vm, bytecode::BinaryOperator::Subscr) } @@ 
-3918,19 +4163,15 @@ impl ExecutingFrame<'_> { if let (Some(a_str), Some(b_int)) = ( a.downcast_ref_if_exact::(vm), b.downcast_ref_if_exact::(vm), - ) && let Ok(i) = b_int.try_to_primitive::(vm) + ) && let Some(i) = specialization_nonnegative_compact_index(b_int, vm) + && let Ok(ch) = a_str.getitem_by_index(vm, i as isize) + && ch.is_ascii() { - match a_str.getitem_by_index(vm, i) { - Ok(ch) => { - self.pop_value(); - self.pop_value(); - self.push_value(PyStr::from(ch).into_pyobject(vm)); - return Ok(None); - } - Err(e) => { - return Err(e); - } - } + let ascii_idx = ch.to_u32() as usize; + self.pop_value(); + self.pop_value(); + self.push_value(vm.ctx.ascii_char_cache[ascii_idx].clone().into()); + return Ok(None); } self.execute_bin_op(vm, bytecode::BinaryOperator::Subscr) } @@ -3953,16 +4194,34 @@ impl ExecutingFrame<'_> { let cache_base = instr_idx + 1; let cached_version = self.code.instructions.read_cache_u32(cache_base + 1); let nargs: u32 = arg.into(); + if self.specialization_eval_frame_active(vm) { + return self.execute_call_vectorcall(nargs, vm); + } // Stack: [callable, self_or_null, arg1, ..., argN] + let stack_len = self.localsplus.stack_len(); + let self_or_null_is_some = self + .localsplus + .stack_index(stack_len - nargs as usize - 1) + .is_some(); let callable = self.nth_value(nargs + 1); - if let Some(func) = callable.downcast_ref::() + if let Some(func) = callable.downcast_ref_if_exact::(vm) && func.func_version() == cached_version && cached_version != 0 { + let effective_nargs = nargs + u32::from(self_or_null_is_some); + if !func.has_exact_argcount(effective_nargs) { + return self.execute_call_vectorcall(nargs, vm); + } + if !self.specialization_has_datastack_space_for_func(vm, func) { + return self.execute_call_vectorcall(nargs, vm); + } + if self.specialization_call_recursion_guard(vm) { + return self.execute_call_vectorcall(nargs, vm); + } let pos_args: Vec = self.pop_multiple(nargs as usize).collect(); let self_or_null = 
self.pop_value_opt(); let callable = self.pop_value(); - let func = callable.downcast_ref::().unwrap(); + let func = callable.downcast_ref_if_exact::(vm).unwrap(); let args = if let Some(self_val) = self_or_null { let mut all_args = Vec::with_capacity(pos_args.len() + 1); all_args.push(self_val); @@ -3975,8 +4234,7 @@ impl ExecutingFrame<'_> { self.push_value(result); Ok(None) } else { - let args = self.collect_positional_args(nargs); - self.execute_call(args, vm) + self.execute_call_vectorcall(nargs, vm) } } Instruction::CallBoundMethodExactArgs => { @@ -3984,6 +4242,9 @@ impl ExecutingFrame<'_> { let cache_base = instr_idx + 1; let cached_version = self.code.instructions.read_cache_u32(cache_base + 1); let nargs: u32 = arg.into(); + if self.specialization_eval_frame_active(vm) { + return self.execute_call_vectorcall(nargs, vm); + } // Stack: [callable, self_or_null(NULL), arg1, ..., argN] let stack_len = self.localsplus.stack_len(); let self_or_null_is_some = self @@ -3992,14 +4253,23 @@ impl ExecutingFrame<'_> { .is_some(); let callable = self.nth_value(nargs + 1); if !self_or_null_is_some - && let Some(bound_method) = callable.downcast_ref::() + && let Some(bound_method) = callable.downcast_ref_if_exact::(vm) { let bound_function = bound_method.function_obj().clone(); let bound_self = bound_method.self_obj().clone(); - if let Some(func) = bound_function.downcast_ref::() + if let Some(func) = bound_function.downcast_ref_if_exact::(vm) && func.func_version() == cached_version && cached_version != 0 { + if !func.has_exact_argcount(nargs + 1) { + return self.execute_call_vectorcall(nargs, vm); + } + if !self.specialization_has_datastack_space_for_func(vm, func) { + return self.execute_call_vectorcall(nargs, vm); + } + if self.specialization_call_recursion_guard(vm) { + return self.execute_call_vectorcall(nargs, vm); + } let pos_args: Vec = self.pop_multiple(nargs as usize).collect(); self.pop_value_opt(); // null (self_or_null) @@ -4011,28 +4281,23 @@ impl 
ExecutingFrame<'_> { self.push_value(result); return Ok(None); } - let args = self.collect_positional_args(nargs); - self.execute_call(args, vm) - } else { - let args = self.collect_positional_args(nargs); - self.execute_call(args, vm) } + self.execute_call_vectorcall(nargs, vm) } Instruction::CallLen => { - let instr_idx = self.lasti() as usize - 1; - let cache_base = instr_idx + 1; - let cached_tag = self.code.instructions.read_cache_u32(cache_base + 1); let nargs: u32 = arg.into(); if nargs == 1 { // Stack: [callable, null, arg] let obj = self.pop_value(); // arg let null = self.pop_value_opt(); let callable = self.pop_value(); - let callable_tag = &*callable as *const PyObject as u32; - let is_len_callable = callable - .downcast_ref_if_exact::(vm) - .is_some_and(|native| native.zelf.is_none() && native.value.name == "len"); - if null.is_none() && cached_tag == callable_tag && is_len_callable { + if null.is_none() + && vm + .callable_cache + .len + .as_ref() + .is_some_and(|len_callable| callable.is(len_callable)) + { let len = obj.length(vm)?; self.push_value(vm.ctx.new_int(len).into()); return Ok(None); @@ -4042,16 +4307,9 @@ impl ExecutingFrame<'_> { self.push_value_opt(null); self.push_value(obj); } - self.deoptimize(Instruction::Call { - argc: Arg::marker(), - }); - let args = self.collect_positional_args(nargs); - self.execute_call(args, vm) + self.execute_call_vectorcall(nargs, vm) } Instruction::CallIsinstance => { - let instr_idx = self.lasti() as usize - 1; - let cache_base = instr_idx + 1; - let cached_tag = self.code.instructions.read_cache_u32(cache_base + 1); let nargs: u32 = arg.into(); let stack_len = self.localsplus.stack_len(); let self_or_null_is_some = self @@ -4061,13 +4319,12 @@ impl ExecutingFrame<'_> { let effective_nargs = nargs + u32::from(self_or_null_is_some); if effective_nargs == 2 { let callable = self.nth_value(nargs + 1); - let callable_tag = callable as *const PyObject as u32; - let is_isinstance_callable = callable - 
.downcast_ref_if_exact::(vm) - .is_some_and(|native| { - native.zelf.is_none() && native.value.name == "isinstance" - }); - if cached_tag == callable_tag && is_isinstance_callable { + if vm + .callable_cache + .isinstance + .as_ref() + .is_some_and(|isinstance_callable| callable.is(isinstance_callable)) + { let nargs_usize = nargs as usize; let pos_args: Vec = self.pop_multiple(nargs_usize).collect(); let self_or_null = self.pop_value_opt(); @@ -4082,11 +4339,7 @@ impl ExecutingFrame<'_> { return Ok(None); } } - self.deoptimize(Instruction::Call { - argc: Arg::marker(), - }); - let args = self.collect_positional_args(nargs); - self.execute_call(args, vm) + self.execute_call_vectorcall(nargs, vm) } Instruction::CallType1 => { let nargs: u32 = arg.into(); @@ -4105,11 +4358,7 @@ impl ExecutingFrame<'_> { self.push_value_opt(null); self.push_value(obj); } - self.deoptimize(Instruction::Call { - argc: Arg::marker(), - }); - let args = self.collect_positional_args(nargs); - self.execute_call(args, vm) + self.execute_call_vectorcall(nargs, vm) } Instruction::CallStr1 => { let nargs: u32 = arg.into(); @@ -4126,11 +4375,7 @@ impl ExecutingFrame<'_> { self.push_value_opt(null); self.push_value(obj); } - self.deoptimize(Instruction::Call { - argc: Arg::marker(), - }); - let args = self.collect_positional_args(nargs); - self.execute_call(args, vm) + self.execute_call_vectorcall(nargs, vm) } Instruction::CallTuple1 => { let nargs: u32 = arg.into(); @@ -4152,11 +4397,7 @@ impl ExecutingFrame<'_> { self.push_value_opt(null); self.push_value(obj); } - self.deoptimize(Instruction::Call { - argc: Arg::marker(), - }); - let args = self.collect_positional_args(nargs); - self.execute_call(args, vm) + self.execute_call_vectorcall(nargs, vm) } Instruction::CallBuiltinO => { let nargs: u32 = arg.into(); @@ -4167,30 +4408,30 @@ impl ExecutingFrame<'_> { .is_some(); let effective_nargs = nargs + u32::from(self_or_null_is_some); let callable = self.nth_value(nargs + 1); - if callable - 
.downcast_ref_if_exact::(vm) - .is_some() - && effective_nargs == 1 - { - let nargs_usize = nargs as usize; - let pos_args: Vec = self.pop_multiple(nargs_usize).collect(); - let self_or_null = self.pop_value_opt(); - let callable = self.pop_value(); - let mut args_vec = Vec::with_capacity(effective_nargs as usize); - if let Some(self_val) = self_or_null { - args_vec.push(self_val); + if let Some(native) = callable.downcast_ref_if_exact::(vm) { + let call_conv = native.value.flags + & (PyMethodFlags::VARARGS + | PyMethodFlags::FASTCALL + | PyMethodFlags::NOARGS + | PyMethodFlags::O + | PyMethodFlags::KEYWORDS); + if call_conv == PyMethodFlags::O && effective_nargs == 1 { + let nargs_usize = nargs as usize; + let pos_args: Vec = self.pop_multiple(nargs_usize).collect(); + let self_or_null = self.pop_value_opt(); + let callable = self.pop_value(); + let mut args_vec = Vec::with_capacity(effective_nargs as usize); + if let Some(self_val) = self_or_null { + args_vec.push(self_val); + } + args_vec.extend(pos_args); + let result = + callable.vectorcall(args_vec, effective_nargs as usize, None, vm)?; + self.push_value(result); + return Ok(None); } - args_vec.extend(pos_args); - let result = - callable.vectorcall(args_vec, effective_nargs as usize, None, vm)?; - self.push_value(result); - return Ok(None); } - self.deoptimize(Instruction::Call { - argc: Arg::marker(), - }); - let args = self.collect_positional_args(nargs); - self.execute_call(args, vm) + self.execute_call_vectorcall(nargs, vm) } Instruction::CallBuiltinFast => { let nargs: u32 = arg.into(); @@ -4201,40 +4442,47 @@ impl ExecutingFrame<'_> { .is_some(); let effective_nargs = nargs + u32::from(self_or_null_is_some); let callable = self.nth_value(nargs + 1); - if callable - .downcast_ref_if_exact::(vm) - .is_some() - { - let nargs_usize = nargs as usize; - let pos_args: Vec = self.pop_multiple(nargs_usize).collect(); - let self_or_null = self.pop_value_opt(); - let callable = self.pop_value(); - let mut args_vec 
= Vec::with_capacity(effective_nargs as usize); - if let Some(self_val) = self_or_null { - args_vec.push(self_val); + if let Some(native) = callable.downcast_ref_if_exact::(vm) { + let call_conv = native.value.flags + & (PyMethodFlags::VARARGS + | PyMethodFlags::FASTCALL + | PyMethodFlags::NOARGS + | PyMethodFlags::O + | PyMethodFlags::KEYWORDS); + if call_conv == PyMethodFlags::FASTCALL { + let nargs_usize = nargs as usize; + let pos_args: Vec = self.pop_multiple(nargs_usize).collect(); + let self_or_null = self.pop_value_opt(); + let callable = self.pop_value(); + let mut args_vec = Vec::with_capacity(effective_nargs as usize); + if let Some(self_val) = self_or_null { + args_vec.push(self_val); + } + args_vec.extend(pos_args); + let result = + callable.vectorcall(args_vec, effective_nargs as usize, None, vm)?; + self.push_value(result); + return Ok(None); } - args_vec.extend(pos_args); - let result = - callable.vectorcall(args_vec, effective_nargs as usize, None, vm)?; - self.push_value(result); - return Ok(None); } - self.deoptimize(Instruction::Call { - argc: Arg::marker(), - }); - let args = self.collect_positional_args(nargs); - self.execute_call(args, vm) + self.execute_call_vectorcall(nargs, vm) } Instruction::CallPyGeneral => { let instr_idx = self.lasti() as usize - 1; let cache_base = instr_idx + 1; let cached_version = self.code.instructions.read_cache_u32(cache_base + 1); let nargs: u32 = arg.into(); + if self.specialization_eval_frame_active(vm) { + return self.execute_call_vectorcall(nargs, vm); + } let callable = self.nth_value(nargs + 1); - if let Some(func) = callable.downcast_ref::() + if let Some(func) = callable.downcast_ref_if_exact::(vm) && func.func_version() == cached_version && cached_version != 0 { + if self.specialization_call_recursion_guard(vm) { + return self.execute_call_vectorcall(nargs, vm); + } let nargs_usize = nargs as usize; let pos_args: Vec = self.pop_multiple(nargs_usize).collect(); let self_or_null = self.pop_value_opt(); 
@@ -4252,11 +4500,7 @@ impl ExecutingFrame<'_> { self.push_value(result); Ok(None) } else { - self.deoptimize(Instruction::Call { - argc: Arg::marker(), - }); - let args = self.collect_positional_args(nargs); - self.execute_call(args, vm) + self.execute_call_vectorcall(nargs, vm) } } Instruction::CallBoundMethodGeneral => { @@ -4264,6 +4508,9 @@ impl ExecutingFrame<'_> { let cache_base = instr_idx + 1; let cached_version = self.code.instructions.read_cache_u32(cache_base + 1); let nargs: u32 = arg.into(); + if self.specialization_eval_frame_active(vm) { + return self.execute_call_vectorcall(nargs, vm); + } let stack_len = self.localsplus.stack_len(); let self_or_null_is_some = self .localsplus @@ -4271,14 +4518,17 @@ impl ExecutingFrame<'_> { .is_some(); let callable = self.nth_value(nargs + 1); if !self_or_null_is_some - && let Some(bound_method) = callable.downcast_ref::() + && let Some(bound_method) = callable.downcast_ref_if_exact::(vm) { let bound_function = bound_method.function_obj().clone(); let bound_self = bound_method.self_obj().clone(); - if let Some(func) = bound_function.downcast_ref::() + if let Some(func) = bound_function.downcast_ref_if_exact::(vm) && func.func_version() == cached_version && cached_version != 0 { + if self.specialization_call_recursion_guard(vm) { + return self.execute_call_vectorcall(nargs, vm); + } let nargs_usize = nargs as usize; let pos_args: Vec = self.pop_multiple(nargs_usize).collect(); self.pop_value_opt(); // null (self_or_null) @@ -4296,15 +4546,8 @@ impl ExecutingFrame<'_> { self.push_value(result); return Ok(None); } - let args = self.collect_positional_args(nargs); - self.execute_call(args, vm) - } else { - self.deoptimize(Instruction::Call { - argc: Arg::marker(), - }); - let args = self.collect_positional_args(nargs); - self.execute_call(args, vm) } + self.execute_call_vectorcall(nargs, vm) } Instruction::CallListAppend => { let nargs: u32 = arg.into(); @@ -4313,24 +4556,24 @@ impl ExecutingFrame<'_> { let stack_len 
= self.localsplus.stack_len(); let self_or_null_is_some = self.localsplus.stack_index(stack_len - 2).is_some(); let callable = self.nth_value(2); - let self_is_exact_list = self + let self_is_list = self .localsplus .stack_index(stack_len - 2) .as_ref() - .is_some_and(|obj| obj.class().is(vm.ctx.types.list_type)); - let is_list_append = - callable - .downcast_ref::() - .is_some_and(|descr| { - descr.method.name == "append" - && descr.objclass.is(vm.ctx.types.list_type) - }); - if is_list_append && self_or_null_is_some && self_is_exact_list { + .is_some_and(|obj| obj.downcast_ref::().is_some()); + if vm + .callable_cache + .list_append + .as_ref() + .is_some_and(|list_append| callable.is(list_append)) + && self_or_null_is_some + && self_is_list + { let item = self.pop_value(); let self_or_null = self.pop_value_opt(); let callable = self.pop_value(); if let Some(list_obj) = self_or_null.as_ref() - && let Some(list) = list_obj.downcast_ref_if_exact::(vm) + && let Some(list) = list_obj.downcast_ref::() { list.append(item); // CALL_LIST_APPEND fuses the following POP_TOP. 
@@ -4345,31 +4588,46 @@ impl ExecutingFrame<'_> { self.push_value(item); } } - self.deoptimize(Instruction::Call { - argc: Arg::marker(), - }); - let args = self.collect_positional_args(nargs); - self.execute_call(args, vm) + self.execute_call_vectorcall(nargs, vm) } Instruction::CallMethodDescriptorNoargs => { let nargs: u32 = arg.into(); - if nargs == 0 { - // Stack: [callable, self_or_null] — peek to get func ptr - let stack_len = self.localsplus.stack_len(); - let self_or_null_is_some = self.localsplus.stack_index(stack_len - 1).is_some(); - let callable = self.nth_value(1); - let func = if self_or_null_is_some { - callable - .downcast_ref::() - .map(|d| d.method.func) - } else { - None - }; - if let Some(func) = func { - let self_val = self.pop_value_opt().unwrap(); + let stack_len = self.localsplus.stack_len(); + let self_or_null_is_some = self + .localsplus + .stack_index(stack_len - nargs as usize - 1) + .is_some(); + let total_nargs = nargs + u32::from(self_or_null_is_some); + if total_nargs == 1 { + let callable = self.nth_value(nargs + 1); + let self_index = + stack_len - nargs as usize - 1 + usize::from(!self_or_null_is_some); + if let Some(descr) = callable.downcast_ref_if_exact::(vm) + && (descr.method.flags + & (PyMethodFlags::VARARGS + | PyMethodFlags::FASTCALL + | PyMethodFlags::NOARGS + | PyMethodFlags::O + | PyMethodFlags::KEYWORDS)) + == PyMethodFlags::NOARGS + && self + .localsplus + .stack_index(self_index) + .as_ref() + .is_some_and(|self_obj| self_obj.class().is(descr.objclass)) + { + let func = descr.method.func; + let positional_args: Vec = + self.pop_multiple(nargs as usize).collect(); + let self_or_null = self.pop_value_opt(); self.pop_value(); // callable + let mut all_args = Vec::with_capacity(total_nargs as usize); + if let Some(self_val) = self_or_null { + all_args.push(self_val); + } + all_args.extend(positional_args); let args = FuncArgs { - args: vec![self_val], + args: all_args, kwargs: Default::default(), }; let result = 
func(vm, args)?; @@ -4377,31 +4635,46 @@ impl ExecutingFrame<'_> { return Ok(None); } } - self.deoptimize(Instruction::Call { - argc: Arg::marker(), - }); self.execute_call_vectorcall(nargs, vm) } Instruction::CallMethodDescriptorO => { let nargs: u32 = arg.into(); - if nargs == 1 { - // Stack: [callable, self_or_null, arg1] - let stack_len = self.localsplus.stack_len(); - let self_or_null_is_some = self.localsplus.stack_index(stack_len - 2).is_some(); - let callable = self.nth_value(2); - let func = if self_or_null_is_some { - callable - .downcast_ref::() - .map(|d| d.method.func) - } else { - None - }; - if let Some(func) = func { - let obj = self.pop_value(); - let self_val = self.pop_value_opt().unwrap(); + let stack_len = self.localsplus.stack_len(); + let self_or_null_is_some = self + .localsplus + .stack_index(stack_len - nargs as usize - 1) + .is_some(); + let total_nargs = nargs + u32::from(self_or_null_is_some); + if total_nargs == 2 { + let callable = self.nth_value(nargs + 1); + let self_index = + stack_len - nargs as usize - 1 + usize::from(!self_or_null_is_some); + if let Some(descr) = callable.downcast_ref_if_exact::(vm) + && (descr.method.flags + & (PyMethodFlags::VARARGS + | PyMethodFlags::FASTCALL + | PyMethodFlags::NOARGS + | PyMethodFlags::O + | PyMethodFlags::KEYWORDS)) + == PyMethodFlags::O + && self + .localsplus + .stack_index(self_index) + .as_ref() + .is_some_and(|self_obj| self_obj.class().is(descr.objclass)) + { + let func = descr.method.func; + let positional_args: Vec = + self.pop_multiple(nargs as usize).collect(); + let self_or_null = self.pop_value_opt(); self.pop_value(); // callable + let mut all_args = Vec::with_capacity(total_nargs as usize); + if let Some(self_val) = self_or_null { + all_args.push(self_val); + } + all_args.extend(positional_args); let args = FuncArgs { - args: vec![self_val, obj], + args: all_args, kwargs: Default::default(), }; let result = func(vm, args)?; @@ -4409,33 +4682,43 @@ impl ExecutingFrame<'_> { 
return Ok(None); } } - self.deoptimize(Instruction::Call { - argc: Arg::marker(), - }); self.execute_call_vectorcall(nargs, vm) } Instruction::CallMethodDescriptorFast => { let nargs: u32 = arg.into(); - let callable = self.nth_value(nargs + 1); let stack_len = self.localsplus.stack_len(); let self_or_null_is_some = self .localsplus .stack_index(stack_len - nargs as usize - 1) .is_some(); - let func = if self_or_null_is_some { - callable - .downcast_ref::() - .map(|d| d.method.func) - } else { - None - }; - if let Some(func) = func { + let total_nargs = nargs + u32::from(self_or_null_is_some); + let callable = self.nth_value(nargs + 1); + let self_index = + stack_len - nargs as usize - 1 + usize::from(!self_or_null_is_some); + if total_nargs > 0 + && let Some(descr) = callable.downcast_ref_if_exact::(vm) + && (descr.method.flags + & (PyMethodFlags::VARARGS + | PyMethodFlags::FASTCALL + | PyMethodFlags::NOARGS + | PyMethodFlags::O + | PyMethodFlags::KEYWORDS)) + == PyMethodFlags::FASTCALL + && self + .localsplus + .stack_index(self_index) + .as_ref() + .is_some_and(|self_obj| self_obj.class().is(descr.objclass)) + { + let func = descr.method.func; let positional_args: Vec = self.pop_multiple(nargs as usize).collect(); - let self_val = self.pop_value_opt().unwrap(); + let self_or_null = self.pop_value_opt(); self.pop_value(); // callable - let mut all_args = Vec::with_capacity(nargs as usize + 1); - all_args.push(self_val); + let mut all_args = Vec::with_capacity(total_nargs as usize); + if let Some(self_val) = self_or_null { + all_args.push(self_val); + } all_args.extend(positional_args); let args = FuncArgs { args: all_args, @@ -4445,9 +4728,6 @@ impl ExecutingFrame<'_> { self.push_value(result); return Ok(None); } - self.deoptimize(Instruction::Call { - argc: Arg::marker(), - }); self.execute_call_vectorcall(nargs, vm) } Instruction::CallBuiltinClass => { @@ -4488,81 +4768,84 @@ impl ExecutingFrame<'_> { .localsplus .stack_index(stack_len - nargs as usize - 1) 
.is_some(); - if !self_or_null_is_some + if !self.specialization_eval_frame_active(vm) + && !self_or_null_is_some && cached_version != 0 && let Some(cls) = callable.downcast_ref::() && cls.tp_version_tag.load(Acquire) == cached_version + && let Some(init_func) = cls.get_cached_init_for_specialization(cached_version) + && let Some(cls_alloc) = cls.slots.alloc.load() { - // Look up __init__ (guarded by type_version) - if let Some(init) = cls.get_attr(identifier!(vm, __init__)) - && let Some(init_func) = init.downcast_ref::() - && init_func.can_specialize_call(nargs + 1) - { - // Allocate object directly (tp_new == object.__new__) - let dict = if cls - .slots - .flags - .has_feature(crate::types::PyTypeFlags::HAS_DICT) - { - Some(vm.ctx.new_dict()) - } else { - None - }; - let cls_ref = cls.to_owned(); - let new_obj: PyObjectRef = - PyRef::new_ref(PyBaseObject, cls_ref, dict).into(); - - // Build args: [new_obj, arg1, ..., argN] - let pos_args: Vec = - self.pop_multiple(nargs as usize).collect(); - let _null = self.pop_value_opt(); // self_or_null (None) - let _callable = self.pop_value(); // callable (type) - - let mut all_args = Vec::with_capacity(pos_args.len() + 1); - all_args.push(new_obj.clone()); - all_args.extend(pos_args); - - let init_result = init_func.invoke_exact_args(all_args, vm)?; - - // EXIT_INIT_CHECK: __init__ must return None - if !vm.is_none(&init_result) { - return Err( - vm.new_type_error("__init__() should return None".to_owned()) - ); - } - - self.push_value(new_obj); - return Ok(None); + // Match CPython's `code->co_framesize + _Py_InitCleanup.co_framesize` + // shape, using RustPython's datastack-backed frame size + // equivalent for the extra shim frame. 
+ let init_cleanup_stack_bytes = + datastack_frame_size_bytes_for_code(&vm.ctx.init_cleanup_code) + .expect("_Py_InitCleanup shim is not a generator/coroutine"); + if !self.specialization_has_datastack_space_for_func_with_extra( + vm, + &init_func, + init_cleanup_stack_bytes, + ) { + return self.execute_call_vectorcall(nargs, vm); + } + // CPython creates `_Py_InitCleanup` + `__init__` frames here. + // Keep the guard conservative and deopt when the effective + // recursion budget for those two frames is not available. + if self.specialization_call_recursion_guard_with_extra_frames(vm, 1) { + return self.execute_call_vectorcall(nargs, vm); } + // Allocate object directly (tp_new == object.__new__, tp_alloc == generic). + let cls_ref = cls.to_owned(); + let new_obj = cls_alloc(cls_ref, 0, vm)?; + + // Build args: [new_obj, arg1, ..., argN] + let pos_args: Vec = self.pop_multiple(nargs as usize).collect(); + let _null = self.pop_value_opt(); // self_or_null (None) + let _callable = self.pop_value(); // callable (type) + let result = self + .specialization_run_init_cleanup_shim(new_obj, &init_func, pos_args, vm)?; + self.push_value(result); + return Ok(None); } - self.deoptimize(Instruction::Call { - argc: Arg::marker(), - }); self.execute_call_vectorcall(nargs, vm) } Instruction::CallMethodDescriptorFastWithKeywords => { // Native function interface is uniform regardless of keyword support let nargs: u32 = arg.into(); - let callable = self.nth_value(nargs + 1); let stack_len = self.localsplus.stack_len(); let self_or_null_is_some = self .localsplus .stack_index(stack_len - nargs as usize - 1) .is_some(); - let func = if self_or_null_is_some { - callable - .downcast_ref::() - .map(|d| d.method.func) - } else { - None - }; - if let Some(func) = func { + let total_nargs = nargs + u32::from(self_or_null_is_some); + let callable = self.nth_value(nargs + 1); + let self_index = + stack_len - nargs as usize - 1 + usize::from(!self_or_null_is_some); + if total_nargs > 0 + && 
let Some(descr) = callable.downcast_ref_if_exact::(vm) + && (descr.method.flags + & (PyMethodFlags::VARARGS + | PyMethodFlags::FASTCALL + | PyMethodFlags::NOARGS + | PyMethodFlags::O + | PyMethodFlags::KEYWORDS)) + == (PyMethodFlags::FASTCALL | PyMethodFlags::KEYWORDS) + && self + .localsplus + .stack_index(self_index) + .as_ref() + .is_some_and(|self_obj| self_obj.class().is(descr.objclass)) + { + let func = descr.method.func; let positional_args: Vec = self.pop_multiple(nargs as usize).collect(); - let self_val = self.pop_value_opt().unwrap(); + let self_or_null = self.pop_value_opt(); self.pop_value(); // callable - let mut all_args = Vec::with_capacity(nargs as usize + 1); - all_args.push(self_val); + let mut all_args = Vec::with_capacity(total_nargs as usize); + if let Some(self_val) = self_or_null { + all_args.push(self_val); + } all_args.extend(positional_args); let args = FuncArgs { args: all_args, @@ -4572,9 +4855,6 @@ impl ExecutingFrame<'_> { self.push_value(result); return Ok(None); } - self.deoptimize(Instruction::Call { - argc: Arg::marker(), - }); self.execute_call_vectorcall(nargs, vm) } Instruction::CallBuiltinFastWithKeywords => { @@ -4587,27 +4867,29 @@ impl ExecutingFrame<'_> { .is_some(); let effective_nargs = nargs + u32::from(self_or_null_is_some); let callable = self.nth_value(nargs + 1); - if callable - .downcast_ref_if_exact::(vm) - .is_some() - { - let nargs_usize = nargs as usize; - let pos_args: Vec = self.pop_multiple(nargs_usize).collect(); - let self_or_null = self.pop_value_opt(); - let callable = self.pop_value(); - let mut args_vec = Vec::with_capacity(effective_nargs as usize); - if let Some(self_val) = self_or_null { - args_vec.push(self_val); + if let Some(native) = callable.downcast_ref_if_exact::(vm) { + let call_conv = native.value.flags + & (PyMethodFlags::VARARGS + | PyMethodFlags::FASTCALL + | PyMethodFlags::NOARGS + | PyMethodFlags::O + | PyMethodFlags::KEYWORDS); + if call_conv == (PyMethodFlags::FASTCALL | 
PyMethodFlags::KEYWORDS) { + let nargs_usize = nargs as usize; + let pos_args: Vec = self.pop_multiple(nargs_usize).collect(); + let self_or_null = self.pop_value_opt(); + let callable = self.pop_value(); + let mut args_vec = Vec::with_capacity(effective_nargs as usize); + if let Some(self_val) = self_or_null { + args_vec.push(self_val); + } + args_vec.extend(pos_args); + let result = + callable.vectorcall(args_vec, effective_nargs as usize, None, vm)?; + self.push_value(result); + return Ok(None); } - args_vec.extend(pos_args); - let result = - callable.vectorcall(args_vec, effective_nargs as usize, None, vm)?; - self.push_value(result); - return Ok(None); } - self.deoptimize(Instruction::Call { - argc: Arg::marker(), - }); self.execute_call_vectorcall(nargs, vm) } Instruction::CallNonPyGeneral => { @@ -4618,11 +4900,12 @@ impl ExecutingFrame<'_> { .stack_index(stack_len - nargs as usize - 1) .is_some(); let callable = self.nth_value(nargs + 1); - if callable.downcast_ref::().is_some() - || callable.downcast_ref::().is_some() + if callable.downcast_ref_if_exact::(vm).is_some() + || callable + .downcast_ref_if_exact::(vm) + .is_some() { - let args = self.collect_positional_args(nargs); - return self.execute_call(args, vm); + return self.execute_call_vectorcall(nargs, vm); } let nargs_usize = nargs as usize; let pos_args: Vec = self.pop_multiple(nargs_usize).collect(); @@ -4648,12 +4931,18 @@ impl ExecutingFrame<'_> { let cache_base = instr_idx + 1; let cached_version = self.code.instructions.read_cache_u32(cache_base + 1); let nargs: u32 = arg.into(); + if self.specialization_eval_frame_active(vm) { + return self.execute_call_kw_vectorcall(nargs, vm); + } // Stack: [callable, self_or_null, arg1, ..., argN, kwarg_names] let callable = self.nth_value(nargs + 2); - if let Some(func) = callable.downcast_ref::() + if let Some(func) = callable.downcast_ref_if_exact::(vm) && func.func_version() == cached_version && cached_version != 0 { + if 
self.specialization_call_recursion_guard(vm) { + return self.execute_call_kw_vectorcall(nargs, vm); + } let nargs_usize = nargs as usize; let kwarg_names_obj = self.pop_value(); let kwarg_names_tuple = kwarg_names_obj @@ -4683,17 +4972,16 @@ impl ExecutingFrame<'_> { self.push_value(result); return Ok(None); } - self.deoptimize(Instruction::CallKw { - argc: Arg::marker(), - }); - let args = self.collect_keyword_args(nargs); - self.execute_call(args, vm) + self.execute_call_kw_vectorcall(nargs, vm) } Instruction::CallKwBoundMethod => { let instr_idx = self.lasti() as usize - 1; let cache_base = instr_idx + 1; let cached_version = self.code.instructions.read_cache_u32(cache_base + 1); let nargs: u32 = arg.into(); + if self.specialization_eval_frame_active(vm) { + return self.execute_call_kw_vectorcall(nargs, vm); + } // Stack: [callable, self_or_null, arg1, ..., argN, kwarg_names] let stack_len = self.localsplus.stack_len(); let self_or_null_is_some = self @@ -4702,11 +4990,11 @@ impl ExecutingFrame<'_> { .is_some(); let callable = self.nth_value(nargs + 2); if !self_or_null_is_some - && let Some(bound_method) = callable.downcast_ref::() + && let Some(bound_method) = callable.downcast_ref_if_exact::(vm) { let bound_function = bound_method.function_obj().clone(); let bound_self = bound_method.self_obj().clone(); - if let Some(func) = bound_function.downcast_ref::() + if let Some(func) = bound_function.downcast_ref_if_exact::(vm) && func.func_version() == cached_version && cached_version != 0 { @@ -4735,11 +5023,7 @@ impl ExecutingFrame<'_> { return Ok(None); } } - self.deoptimize(Instruction::CallKw { - argc: Arg::marker(), - }); - let args = self.collect_keyword_args(nargs); - self.execute_call(args, vm) + self.execute_call_kw_vectorcall(nargs, vm) } Instruction::CallKwNonPy => { let nargs: u32 = arg.into(); @@ -4749,11 +5033,12 @@ impl ExecutingFrame<'_> { .stack_index(stack_len - nargs as usize - 2) .is_some(); let callable = self.nth_value(nargs + 2); - if 
callable.downcast_ref::().is_some() - || callable.downcast_ref::().is_some() + if callable.downcast_ref_if_exact::(vm).is_some() + || callable + .downcast_ref_if_exact::(vm) + .is_some() { - let args = self.collect_keyword_args(nargs); - return self.execute_call(args, vm); + return self.execute_call_kw_vectorcall(nargs, vm); } let nargs_usize = nargs as usize; let kwarg_names_obj = self.pop_value(); @@ -4831,22 +5116,6 @@ impl ExecutingFrame<'_> { return Ok(None); } } - // Deoptimize - unsafe { - self.code.instructions.replace_op( - self.lasti() as usize - 1, - Instruction::LoadSuperAttr { - namei: Arg::marker(), - }, - ); - let cache_base = self.lasti() as usize; - self.code.instructions.write_adaptive_counter( - cache_base, - bytecode::adaptive_counter_backoff( - self.code.instructions.read_adaptive_counter(cache_base), - ), - ); - } let oparg = LoadSuperAttr::new(oparg); self.load_super_attr(vm, oparg) } @@ -4914,22 +5183,6 @@ impl ExecutingFrame<'_> { return Ok(None); } } - // Deoptimize - unsafe { - self.code.instructions.replace_op( - self.lasti() as usize - 1, - Instruction::LoadSuperAttr { - namei: Arg::marker(), - }, - ); - let cache_base = self.lasti() as usize; - self.code.instructions.write_adaptive_counter( - cache_base, - bytecode::adaptive_counter_backoff( - self.code.instructions.read_adaptive_counter(cache_base), - ), - ); - } let oparg = LoadSuperAttr::new(oparg); self.load_super_attr(vm, oparg) } @@ -4939,9 +5192,12 @@ impl ExecutingFrame<'_> { if let (Some(a_int), Some(b_int)) = ( a.downcast_ref_if_exact::(vm), b.downcast_ref_if_exact::(vm), + ) && let (Some(a_val), Some(b_val)) = ( + specialization_compact_int_value(a_int, vm), + specialization_compact_int_value(b_int, vm), ) { let op = self.compare_op_from_arg(arg); - let result = op.eval_ord(a_int.as_bigint().cmp(b_int.as_bigint())); + let result = op.eval_ord(a_val.cmp(&b_val)); self.pop_value(); self.pop_value(); self.push_value(vm.ctx.new_bool(result).into()); @@ -4984,6 +5240,11 @@ impl 
ExecutingFrame<'_> { b.downcast_ref_if_exact::(vm), ) { let op = self.compare_op_from_arg(arg); + if op != PyComparisonOp::Eq && op != PyComparisonOp::Ne { + let op = bytecode::ComparisonOperator::try_from(u32::from(arg)) + .unwrap_or(bytecode::ComparisonOperator::Equal); + return self.execute_compare(vm, op); + } let result = op.eval_ord(a_str.as_wtf8().cmp(b_str.as_wtf8())); self.pop_value(); self.pop_value(); @@ -5236,8 +5497,16 @@ impl ExecutingFrame<'_> { Instruction::ForIterGen => { let target = bytecode::Label(self.lasti() + 1 + u32::from(arg)); let iter = self.top_value(); + if self.specialization_eval_frame_active(vm) { + self.execute_for_iter(vm, target)?; + return Ok(None); + } if let Some(generator) = iter.downcast_ref_if_exact::(vm) { - match generator.as_coro().send(iter, vm.ctx.none(), vm) { + if generator.as_coro().running() || generator.as_coro().closed() { + self.execute_for_iter(vm, target)?; + return Ok(None); + } + match generator.as_coro().send_none(iter, vm) { Ok(PyIterReturn::Return(value)) => { self.push_value(value); } @@ -5259,8 +5528,7 @@ impl ExecutingFrame<'_> { Instruction::LoadGlobalModule => { let oparg = u32::from(arg); let cache_base = self.lasti() as usize; - // Keep specialized opcode on guard miss, matching CPython's - // JUMP_TO_PREDICTED(LOAD_GLOBAL) behavior. + // Keep specialized opcode on guard miss (JUMP_TO_PREDICTED behavior). 
let cached_version = self.code.instructions.read_cache_u16(cache_base + 1); let cached_index = self.code.instructions.read_cache_u16(cache_base + 3); if let Ok(current_version) = u16::try_from(self.globals.version()) @@ -5331,6 +5599,18 @@ impl ExecutingFrame<'_> { instruction.is_instrumented(), "execute_instrumented called with non-instrumented opcode {instruction:?}" ); + if self.monitoring_disabled_for_code(vm) { + let global_ver = vm + .state + .instrumentation_version + .load(atomic::Ordering::Acquire); + monitoring::instrument_code(self.code, 0); + self.code + .instrumentation_version + .store(global_ver, atomic::Ordering::Release); + self.update_lasti(|i| *i -= 1); + return Ok(None); + } self.monitoring_mask = vm.state.monitoring_events.load(); match instruction { Instruction::InstrumentedResume => { @@ -6291,7 +6571,7 @@ impl ExecutingFrame<'_> { args }; - let is_python_call = callable.downcast_ref::().is_some(); + let is_python_call = callable.downcast_ref_if_exact::(vm).is_some(); // Fire CALL event let call_arg0 = if self.monitoring_mask & monitoring::EVENT_CALL != 0 { @@ -6360,9 +6640,8 @@ impl ExecutingFrame<'_> { bytecode::RaiseKind::BareRaise => { // RAISE_VARARGS 0: bare `raise` gets exception from VM state // This is the current exception set by PUSH_EXC_INFO - vm.topmost_exception().ok_or_else(|| { - vm.new_runtime_error("No active exception to reraise".to_owned()) - })? + vm.topmost_exception() + .ok_or_else(|| vm.new_runtime_error("No active exception to reraise"))? 
} bytecode::RaiseKind::ReraiseFromStack => { // RERAISE: gets exception from stack top @@ -6582,7 +6861,7 @@ impl ExecutingFrame<'_> { fn execute_set_function_attribute( &mut self, vm: &VirtualMachine, - attr: bytecode::MakeFunctionFlags, + attr: bytecode::MakeFunctionFlag, ) -> FrameResult { // SET_FUNCTION_ATTRIBUTE sets attributes on a function // Stack: [..., attr_value, func] -> [..., func] @@ -6594,7 +6873,7 @@ impl ExecutingFrame<'_> { let func = self.top_value(); // Get the function reference and call the new method let func_ref = func - .downcast_ref::() + .downcast_ref_if_exact::(vm) .expect("SET_FUNCTION_ATTRIBUTE expects function on stack"); let payload: &PyFunction = func_ref.payload(); @@ -6886,6 +7165,14 @@ impl ExecutingFrame<'_> { Ok(None) } + /// Read a cached descriptor pointer and validate it against the expected + /// type version, using a lock-free double-check pattern: + /// 1. read pointer → incref (try_to_owned) + /// 2. re-read version + pointer and confirm they still match + /// + /// This matches the read-side pattern used in LOAD_ATTR_METHOD_WITH_VALUES + /// and friends: no read-side lock, relying on the write side to invalidate + /// the version tag before swapping the pointer. #[inline] fn try_read_cached_descriptor( &self, @@ -6896,7 +7183,12 @@ impl ExecutingFrame<'_> { if descr_ptr == 0 { return None; } + // SAFETY: `descr_ptr` was a valid `*mut PyObject` when the writer + // stored it, and the writer keeps a strong reference alive in + // `InlineCacheEntry`. `try_to_owned_from_ptr` performs a + // conditional incref that fails if the object is already freed. let cloned = unsafe { PyObject::try_to_owned_from_ptr(descr_ptr as *mut PyObject) }?; + // Double-check: version tag still matches AND pointer unchanged. 
if self.code.instructions.read_cache_u32(cache_base + 1) == expected_type_version && self.code.instructions.read_cache_ptr(cache_base + 5) == descr_ptr { @@ -6914,8 +7206,9 @@ impl ExecutingFrame<'_> { type_version: u32, descr_ptr: usize, ) { - // Publish descriptor cache atomically as a tuple: + // Publish descriptor cache with version-invalidation protocol: // invalidate version first, then write payload, then publish version. + // Reader double-checks version+ptr after incref, so no writer lock needed. unsafe { self.code.instructions.write_cache_u32(cache_base + 1, 0); self.code @@ -6935,7 +7228,6 @@ impl ExecutingFrame<'_> { metaclass_version: u32, descr_ptr: usize, ) { - // Same publish protocol as write_cached_descriptor(), plus metaclass guard. unsafe { self.code.instructions.write_cache_u32(cache_base + 1, 0); self.code @@ -6950,6 +7242,51 @@ impl ExecutingFrame<'_> { } } + #[inline] + unsafe fn write_cached_binary_op_extend_descr( + &self, + cache_base: usize, + descr: Option<&'static BinaryOpExtendSpecializationDescr>, + ) { + let ptr = descr.map_or(0, |d| { + d as *const BinaryOpExtendSpecializationDescr as usize + }); + unsafe { + self.code + .instructions + .write_cache_ptr(cache_base + BINARY_OP_EXTEND_EXTERNAL_CACHE_OFFSET, ptr); + } + } + + #[inline] + fn read_cached_binary_op_extend_descr( + &self, + cache_base: usize, + ) -> Option<&'static BinaryOpExtendSpecializationDescr> { + let ptr = self + .code + .instructions + .read_cache_ptr(cache_base + BINARY_OP_EXTEND_EXTERNAL_CACHE_OFFSET); + if ptr == 0 { + return None; + } + // SAFETY: We only store pointers to entries in `BINARY_OP_EXTEND_DESCRIPTORS`. 
+ Some(unsafe { &*(ptr as *const BinaryOpExtendSpecializationDescr) }) + } + + #[inline] + fn binary_op_extended_specialization( + &self, + op: bytecode::BinaryOperator, + lhs: &PyObject, + rhs: &PyObject, + vm: &VirtualMachine, + ) -> Option<&'static BinaryOpExtendSpecializationDescr> { + BINARY_OP_EXTEND_DESCRIPTORS + .iter() + .find(|d| d.oparg == op && (d.guard)(lhs, rhs, vm)) + } + fn load_attr(&mut self, vm: &VirtualMachine, oparg: LoadAttr) -> FrameResult { self.adaptive(|s, ii, cb| s.specialize_load_attr(vm, oparg, ii, cb)); self.load_attr_slow(vm, oparg) @@ -6985,6 +7322,35 @@ impl ExecutingFrame<'_> { .load() .is_some_and(|f| f as usize == PyBaseObject::getattro as *const () as usize); if !is_default_getattro { + let mut type_version = cls.tp_version_tag.load(Acquire); + if type_version == 0 { + type_version = cls.assign_version_tag(); + } + if type_version != 0 + && !oparg.is_method() + && !self.specialization_eval_frame_active(_vm) + && cls.get_attr(identifier!(_vm, __getattr__)).is_none() + && let Some(getattribute) = cls.get_attr(identifier!(_vm, __getattribute__)) + && let Some(func) = getattribute.downcast_ref_if_exact::(_vm) + && func.can_specialize_call(2) + { + let func_version = func.get_version_for_current_state(); + if func_version != 0 { + let func_ptr = &*getattribute as *const PyObject as usize; + unsafe { + self.code + .instructions + .write_cache_u32(cache_base + 3, func_version); + self.write_cached_descriptor(cache_base, type_version, func_ptr); + } + self.specialize_at( + instr_idx, + cache_base, + Instruction::LoadAttrGetattributeOverridden, + ); + return; + } + } unsafe { self.code.instructions.write_adaptive_counter( cache_base, @@ -7014,19 +7380,38 @@ impl ExecutingFrame<'_> { return; } - // Module attribute access: use LoadAttrModule - if obj.downcast_ref_if_exact::(_vm).is_some() { - unsafe { - self.code - .instructions - .write_cache_u32(cache_base + 1, type_version); + let attr_name = self.code.names[oparg.name_idx() as usize]; 
+ + // Match CPython: only specialize module attribute loads when the + // current module dict has no __getattr__ override and the attribute is + // already present. + if let Some(module) = obj.downcast_ref_if_exact::(_vm) { + let module_dict = module.dict(); + match ( + module_dict.get_item_opt(identifier!(_vm, __getattr__), _vm), + module_dict.get_item_opt(attr_name, _vm), + ) { + (Ok(None), Ok(Some(_))) => { + unsafe { + self.code + .instructions + .write_cache_u32(cache_base + 1, type_version); + } + self.specialize_at(instr_idx, cache_base, Instruction::LoadAttrModule); + } + (Ok(_), Ok(_)) => self.cooldown_adaptive_at(cache_base), + _ => unsafe { + self.code.instructions.write_adaptive_counter( + cache_base, + bytecode::adaptive_counter_backoff( + self.code.instructions.read_adaptive_counter(cache_base), + ), + ); + }, } - self.specialize_at(instr_idx, cache_base, Instruction::LoadAttrModule); return; } - let attr_name = self.code.names[oparg.name_idx() as usize]; - // Look up attr in class via MRO let cls_attr = cls.get_attr(attr_name); let class_has_dict = cls.slots.flags.has_feature(PyTypeFlags::HAS_DICT); @@ -7091,12 +7476,16 @@ impl ExecutingFrame<'_> { } self.specialize_at(instr_idx, cache_base, Instruction::LoadAttrSlot); } else if let Some(ref descr) = cls_attr - && descr.downcast_ref::().is_some() + && let Some(prop) = descr.downcast_ref::() + && let Some(fget) = prop.get_fget() + && let Some(func) = fget.downcast_ref_if_exact::(_vm) + && func.can_specialize_call(1) + && !self.specialization_eval_frame_active(_vm) { - // Property descriptor — cache the property object pointer - let descr_ptr = &**descr as *const PyObject as usize; + // Property specialization caches fget directly. 
+ let fget_ptr = &*fget as *const PyObject as usize; unsafe { - self.write_cached_descriptor(cache_base, type_version, descr_ptr); + self.write_cached_descriptor(cache_base, type_version, fget_ptr); } self.specialize_at(instr_idx, cache_base, Instruction::LoadAttrProperty); } else { @@ -7132,8 +7521,11 @@ impl ExecutingFrame<'_> { Instruction::LoadAttrNondescriptorWithValues, ); } else { - // No class attr, must be in instance dict - let use_hint = if let Some(dict) = obj.dict() { + // Match CPython ABSENT/no-shadow behavior: if the + // attribute is missing on both the class and the current + // instance, keep the generic opcode and just enter + // cooldown instead of specializing a repeated miss path. + let has_instance_attr = if let Some(dict) = obj.dict() { match dict.get_item_opt(attr_name, _vm) { Ok(Some(_)) => true, Ok(None) => false, @@ -7154,20 +7546,16 @@ impl ExecutingFrame<'_> { } else { false }; - unsafe { - self.code - .instructions - .write_cache_u32(cache_base + 1, type_version); + if has_instance_attr { + unsafe { + self.code + .instructions + .write_cache_u32(cache_base + 1, type_version); + } + self.specialize_at(instr_idx, cache_base, Instruction::LoadAttrWithHint); + } else { + self.cooldown_adaptive_at(cache_base); } - self.specialize_at( - instr_idx, - cache_base, - if use_hint { - Instruction::LoadAttrWithHint - } else { - Instruction::LoadAttrInstanceValue - }, - ); } } else if let Some(ref descr) = cls_attr { // No dict support, plain class attr — cache directly @@ -7181,15 +7569,8 @@ impl ExecutingFrame<'_> { Instruction::LoadAttrNondescriptorNoDict, ); } else { - // No dict, no class attr — can't specialize - unsafe { - self.code.instructions.write_adaptive_counter( - cache_base, - bytecode::adaptive_counter_backoff( - self.code.instructions.read_adaptive_counter(cache_base), - ), - ); - } + // No dict and no class attr: repeated miss path, so cooldown. 
+ self.cooldown_adaptive_at(cache_base); } } } @@ -7341,6 +7722,11 @@ impl ExecutingFrame<'_> { } let b = self.top_value(); let a = self.nth_value(1); + // `external_cache` in _PyBinaryOpCache is used only by BINARY_OP_EXTEND. + unsafe { + self.write_cached_binary_op_extend_descr(cache_base, None); + } + let mut cached_extend_descr = None; let new_op = match op { bytecode::BinaryOperator::Add => { @@ -7355,7 +7741,17 @@ impl ExecutingFrame<'_> { } else if a.downcast_ref_if_exact::(vm).is_some() && b.downcast_ref_if_exact::(vm).is_some() { - Some(Instruction::BinaryOpAddUnicode) + if self + .binary_op_inplace_unicode_target_local(cache_base, a) + .is_some() + { + Some(Instruction::BinaryOpInplaceAddUnicode) + } else { + Some(Instruction::BinaryOpAddUnicode) + } + } else if let Some(descr) = self.binary_op_extended_specialization(op, a, b, vm) { + cached_extend_descr = Some(descr); + Some(Instruction::BinaryOpExtend) } else { None } @@ -7369,6 +7765,9 @@ impl ExecutingFrame<'_> { && b.downcast_ref_if_exact::(vm).is_some() { Some(Instruction::BinaryOpSubtractFloat) + } else if let Some(descr) = self.binary_op_extended_specialization(op, a, b, vm) { + cached_extend_descr = Some(descr); + Some(Instruction::BinaryOpExtend) } else { None } @@ -7382,38 +7781,126 @@ impl ExecutingFrame<'_> { && b.downcast_ref_if_exact::(vm).is_some() { Some(Instruction::BinaryOpMultiplyFloat) + } else if let Some(descr) = self.binary_op_extended_specialization(op, a, b, vm) { + cached_extend_descr = Some(descr); + Some(Instruction::BinaryOpExtend) + } else { + None + } + } + bytecode::BinaryOperator::TrueDivide => { + if let Some(descr) = self.binary_op_extended_specialization(op, a, b, vm) { + cached_extend_descr = Some(descr); + Some(Instruction::BinaryOpExtend) } else { None } } bytecode::BinaryOperator::Subscr => { - if a.downcast_ref_if_exact::(vm).is_some() - && b.downcast_ref_if_exact::(vm).is_some() - { + let b_is_nonnegative_int = b + .downcast_ref_if_exact::(vm) + .is_some_and(|i| 
specialization_nonnegative_compact_index(i, vm).is_some()); + if a.downcast_ref_if_exact::(vm).is_some() && b_is_nonnegative_int { Some(Instruction::BinaryOpSubscrListInt) - } else if a.downcast_ref_if_exact::(vm).is_some() - && b.downcast_ref_if_exact::(vm).is_some() - { + } else if a.downcast_ref_if_exact::(vm).is_some() && b_is_nonnegative_int { Some(Instruction::BinaryOpSubscrTupleInt) } else if a.downcast_ref_if_exact::(vm).is_some() { Some(Instruction::BinaryOpSubscrDict) - } else if a.downcast_ref_if_exact::(vm).is_some() - && b.downcast_ref_if_exact::(vm).is_some() - { + } else if a.downcast_ref_if_exact::(vm).is_some() && b_is_nonnegative_int { Some(Instruction::BinaryOpSubscrStrInt) } else if a.downcast_ref_if_exact::(vm).is_some() && b.downcast_ref::().is_some() { Some(Instruction::BinaryOpSubscrListSlice) } else { - None + let cls = a.class(); + if cls.slots.flags.has_feature(PyTypeFlags::HEAPTYPE) + && !self.specialization_eval_frame_active(vm) + && let Some(_getitem) = cls.get_attr(identifier!(vm, __getitem__)) + && let Some(func) = _getitem.downcast_ref_if_exact::(vm) + && func.can_specialize_call(2) + { + let mut type_version = cls.tp_version_tag.load(Acquire); + if type_version == 0 { + type_version = cls.assign_version_tag(); + } + if type_version != 0 { + if cls.cache_getitem_for_specialization( + func.to_owned(), + type_version, + vm, + ) { + Some(Instruction::BinaryOpSubscrGetitem) + } else { + None + } + } else { + None + } + } else { + None + } } } bytecode::BinaryOperator::InplaceAdd => { if a.downcast_ref_if_exact::(vm).is_some() && b.downcast_ref_if_exact::(vm).is_some() { - Some(Instruction::BinaryOpInplaceAddUnicode) + if self + .binary_op_inplace_unicode_target_local(cache_base, a) + .is_some() + { + Some(Instruction::BinaryOpInplaceAddUnicode) + } else { + Some(Instruction::BinaryOpAddUnicode) + } + } else if a.downcast_ref_if_exact::(vm).is_some() + && b.downcast_ref_if_exact::(vm).is_some() + { + Some(Instruction::BinaryOpAddInt) + } 
else if a.downcast_ref_if_exact::(vm).is_some() + && b.downcast_ref_if_exact::(vm).is_some() + { + Some(Instruction::BinaryOpAddFloat) + } else { + None + } + } + bytecode::BinaryOperator::InplaceSubtract => { + if a.downcast_ref_if_exact::(vm).is_some() + && b.downcast_ref_if_exact::(vm).is_some() + { + Some(Instruction::BinaryOpSubtractInt) + } else if a.downcast_ref_if_exact::(vm).is_some() + && b.downcast_ref_if_exact::(vm).is_some() + { + Some(Instruction::BinaryOpSubtractFloat) + } else { + None + } + } + bytecode::BinaryOperator::InplaceMultiply => { + if a.downcast_ref_if_exact::(vm).is_some() + && b.downcast_ref_if_exact::(vm).is_some() + { + Some(Instruction::BinaryOpMultiplyInt) + } else if a.downcast_ref_if_exact::(vm).is_some() + && b.downcast_ref_if_exact::(vm).is_some() + { + Some(Instruction::BinaryOpMultiplyFloat) + } else { + None + } + } + bytecode::BinaryOperator::And + | bytecode::BinaryOperator::Or + | bytecode::BinaryOperator::Xor + | bytecode::BinaryOperator::InplaceAnd + | bytecode::BinaryOperator::InplaceOr + | bytecode::BinaryOperator::InplaceXor => { + if let Some(descr) = self.binary_op_extended_specialization(op, a, b, vm) { + cached_extend_descr = Some(descr); + Some(Instruction::BinaryOpExtend) } else { None } @@ -7421,9 +7908,35 @@ impl ExecutingFrame<'_> { _ => None, }; + if matches!(new_op, Some(Instruction::BinaryOpExtend)) { + unsafe { + self.write_cached_binary_op_extend_descr(cache_base, cached_extend_descr); + } + } self.commit_specialization(instr_idx, cache_base, new_op); } + #[inline] + fn binary_op_inplace_unicode_target_local( + &self, + cache_base: usize, + left: &PyObject, + ) -> Option { + let next_idx = cache_base + Instruction::BinaryOp { op: Arg::marker() }.cache_entries(); + let unit = self.code.instructions.get(next_idx)?; + let next_op = unit.op.to_base().unwrap_or(unit.op); + if !matches!(next_op, Instruction::StoreFast { .. 
}) { + return None; + } + let local_idx = usize::from(u8::from(unit.arg)); + self.localsplus + .fastlocals() + .get(local_idx) + .and_then(|slot| slot.as_ref()) + .filter(|local| local.is(left)) + .map(|_| local_idx) + } + /// Adaptive counter: trigger specialization at zero, otherwise advance countdown. #[inline] fn adaptive(&mut self, specialize: impl FnOnce(&mut Self, usize, usize)) { @@ -7453,6 +7966,15 @@ impl ExecutingFrame<'_> { } } + #[inline] + fn cooldown_adaptive_at(&mut self, cache_base: usize) { + unsafe { + self.code + .instructions + .write_adaptive_counter(cache_base, ADAPTIVE_COOLDOWN_VALUE); + } + } + /// Commit a specialization result: replace op on success, backoff on failure. #[inline] fn commit_specialization( @@ -7475,30 +7997,6 @@ impl ExecutingFrame<'_> { } } - /// Deoptimize: replace specialized op with its base adaptive op and reset - /// the adaptive counter. Computes instr_idx/cache_base from lasti(). - #[inline] - fn deoptimize(&mut self, base_op: Instruction) { - let instr_idx = self.lasti() as usize - 1; - let cache_base = instr_idx + 1; - self.deoptimize_at(base_op, instr_idx, cache_base); - } - - /// Deoptimize with explicit indices (for specialized handlers that already - /// have instr_idx/cache_base in scope). - #[inline] - fn deoptimize_at(&mut self, base_op: Instruction, instr_idx: usize, cache_base: usize) { - unsafe { - self.code.instructions.replace_op(instr_idx, base_op); - self.code.instructions.write_adaptive_counter( - cache_base, - bytecode::adaptive_counter_backoff( - self.code.instructions.read_adaptive_counter(cache_base), - ), - ); - } - } - /// Execute a specialized binary op on two int operands. /// Fallback to generic binary op if either operand is not an exact int. 
#[inline] @@ -7572,47 +8070,119 @@ impl ExecutingFrame<'_> { .is_some(); let callable = self.nth_value(nargs + 1); - if let Some(func) = callable.downcast_ref::() { - let version = func.get_version_for_current_state(); - if version == 0 { + if let Some(func) = callable.downcast_ref_if_exact::(vm) { + if self.specialization_eval_frame_active(vm) { + unsafe { + self.code.instructions.write_adaptive_counter( + cache_base, + bytecode::adaptive_counter_backoff( + self.code.instructions.read_adaptive_counter(cache_base), + ), + ); + } + return; + } + if !func.is_optimized_for_call_specialization() { + unsafe { + self.code.instructions.write_adaptive_counter( + cache_base, + bytecode::adaptive_counter_backoff( + self.code.instructions.read_adaptive_counter(cache_base), + ), + ); + } + return; + } + let version = func.get_version_for_current_state(); + if version == 0 { + unsafe { + self.code.instructions.write_adaptive_counter( + cache_base, + bytecode::adaptive_counter_backoff( + self.code.instructions.read_adaptive_counter(cache_base), + ), + ); + } + return; + } + + let effective_nargs = if self_or_null_is_some { + nargs + 1 + } else { + nargs + }; + + let new_op = if func.can_specialize_call(effective_nargs) { + Instruction::CallPyExactArgs + } else { + Instruction::CallPyGeneral + }; + unsafe { + self.code + .instructions + .write_cache_u32(cache_base + 1, version); + } + self.specialize_at(instr_idx, cache_base, new_op); + return; + } + + // Bound Python method object (`method`) specialization. 
+ if !self_or_null_is_some + && let Some(bound_method) = callable.downcast_ref_if_exact::(vm) + { + if let Some(func) = bound_method + .function_obj() + .downcast_ref_if_exact::(vm) + { + if self.specialization_eval_frame_active(vm) { + unsafe { + self.code.instructions.write_adaptive_counter( + cache_base, + bytecode::adaptive_counter_backoff( + self.code.instructions.read_adaptive_counter(cache_base), + ), + ); + } + return; + } + if !func.is_optimized_for_call_specialization() { + unsafe { + self.code.instructions.write_adaptive_counter( + cache_base, + bytecode::adaptive_counter_backoff( + self.code.instructions.read_adaptive_counter(cache_base), + ), + ); + } + return; + } + let version = func.get_version_for_current_state(); + if version == 0 { + unsafe { + self.code.instructions.write_adaptive_counter( + cache_base, + bytecode::adaptive_counter_backoff( + self.code.instructions.read_adaptive_counter(cache_base), + ), + ); + } + return; + } + + let new_op = if func.can_specialize_call(nargs + 1) { + Instruction::CallBoundMethodExactArgs + } else { + Instruction::CallBoundMethodGeneral + }; unsafe { - self.code.instructions.write_adaptive_counter( - cache_base, - bytecode::adaptive_counter_backoff( - self.code.instructions.read_adaptive_counter(cache_base), - ), - ); + self.code + .instructions + .write_cache_u32(cache_base + 1, version); } - return; - } - - let effective_nargs = if self_or_null_is_some { - nargs + 1 - } else { - nargs - }; - - let new_op = if func.can_specialize_call(effective_nargs) { - Instruction::CallPyExactArgs + self.specialize_at(instr_idx, cache_base, new_op); } else { - Instruction::CallPyGeneral - }; - unsafe { - self.code - .instructions - .write_cache_u32(cache_base + 1, version); - } - self.specialize_at(instr_idx, cache_base, new_op); - return; - } - - // Bound Python method object (`method`) specialization. 
- if !self_or_null_is_some - && let Some(bound_method) = callable.downcast_ref::() - && let Some(func) = bound_method.function_obj().downcast_ref::() - { - let version = func.get_version_for_current_state(); - if version == 0 { + // Match CPython: bound methods wrapping non-Python callables + // are not specialized as CALL_NON_PY_GENERAL. unsafe { self.code.instructions.write_adaptive_counter( cache_base, @@ -7621,25 +8191,12 @@ impl ExecutingFrame<'_> { ), ); } - return; - } - - let new_op = if func.can_specialize_call(nargs + 1) { - Instruction::CallBoundMethodExactArgs - } else { - Instruction::CallBoundMethodGeneral - }; - unsafe { - self.code - .instructions - .write_cache_u32(cache_base + 1, version); } - self.specialize_at(instr_idx, cache_base, new_op); return; } // Try to specialize method descriptor calls - if self_or_null_is_some && let Some(descr) = callable.downcast_ref::() { + if let Some(descr) = callable.downcast_ref_if_exact::(vm) { let call_cache_entries = Instruction::CallListAppend.cache_entries(); let next_idx = cache_base + call_cache_entries; let next_is_pop_top = if next_idx < self.code.instructions.len() { @@ -7649,18 +8206,58 @@ impl ExecutingFrame<'_> { false }; - let new_op = if nargs == 1 - && descr.method.name == "append" - && descr.objclass.is(vm.ctx.types.list_type) - && next_is_pop_top - { - Instruction::CallListAppend - } else { - match nargs { - 0 => Instruction::CallMethodDescriptorNoargs, - 1 => Instruction::CallMethodDescriptorO, - _ => Instruction::CallMethodDescriptorFast, + let call_conv = descr.method.flags + & (PyMethodFlags::VARARGS + | PyMethodFlags::FASTCALL + | PyMethodFlags::NOARGS + | PyMethodFlags::O + | PyMethodFlags::KEYWORDS); + let total_nargs = nargs + u32::from(self_or_null_is_some); + + let new_op = if call_conv == PyMethodFlags::NOARGS { + if total_nargs != 1 { + unsafe { + self.code.instructions.write_adaptive_counter( + cache_base, + bytecode::adaptive_counter_backoff( + 
self.code.instructions.read_adaptive_counter(cache_base), + ), + ); + } + return; + } + Instruction::CallMethodDescriptorNoargs + } else if call_conv == PyMethodFlags::O { + if total_nargs != 2 { + unsafe { + self.code.instructions.write_adaptive_counter( + cache_base, + bytecode::adaptive_counter_backoff( + self.code.instructions.read_adaptive_counter(cache_base), + ), + ); + } + return; + } + if self_or_null_is_some + && nargs == 1 + && next_is_pop_top + && vm + .callable_cache + .list_append + .as_ref() + .is_some_and(|list_append| callable.is(list_append)) + { + Instruction::CallListAppend + } else { + Instruction::CallMethodDescriptorO } + } else if call_conv == PyMethodFlags::FASTCALL { + Instruction::CallMethodDescriptorFast + } else if call_conv == (PyMethodFlags::FASTCALL | PyMethodFlags::KEYWORDS) { + Instruction::CallMethodDescriptorFastWithKeywords + } else { + Instruction::CallNonPyGeneral }; self.specialize_at(instr_idx, cache_base, new_op); return; @@ -7669,71 +8266,123 @@ impl ExecutingFrame<'_> { // Try to specialize builtin calls if let Some(native) = callable.downcast_ref_if_exact::(vm) { let effective_nargs = nargs + u32::from(self_or_null_is_some); - let callable_tag = callable as *const PyObject as u32; - let new_op = if native.zelf.is_none() - && native.value.name == "len" - && nargs == 1 - && effective_nargs == 1 - { - Instruction::CallLen - } else if native.zelf.is_none() - && native.value.name == "isinstance" - && effective_nargs == 2 - { - Instruction::CallIsinstance - } else if effective_nargs == 1 { - Instruction::CallBuiltinO + let call_conv = native.value.flags + & (PyMethodFlags::VARARGS + | PyMethodFlags::FASTCALL + | PyMethodFlags::NOARGS + | PyMethodFlags::O + | PyMethodFlags::KEYWORDS); + let new_op = if call_conv == PyMethodFlags::O { + if effective_nargs != 1 { + unsafe { + self.code.instructions.write_adaptive_counter( + cache_base, + bytecode::adaptive_counter_backoff( + 
self.code.instructions.read_adaptive_counter(cache_base), + ), + ); + } + return; + } + if native.zelf.is_none() + && nargs == 1 + && vm + .callable_cache + .len + .as_ref() + .is_some_and(|len_callable| callable.is(len_callable)) + { + Instruction::CallLen + } else { + Instruction::CallBuiltinO + } + } else if call_conv == PyMethodFlags::FASTCALL { + if native.zelf.is_none() + && effective_nargs == 2 + && vm + .callable_cache + .isinstance + .as_ref() + .is_some_and(|isinstance_callable| callable.is(isinstance_callable)) + { + Instruction::CallIsinstance + } else { + Instruction::CallBuiltinFast + } + } else if call_conv == (PyMethodFlags::FASTCALL | PyMethodFlags::KEYWORDS) { + Instruction::CallBuiltinFastWithKeywords } else { - Instruction::CallBuiltinFast + Instruction::CallNonPyGeneral }; - if matches!(new_op, Instruction::CallLen | Instruction::CallIsinstance) { - unsafe { - self.code - .instructions - .write_cache_u32(cache_base + 1, callable_tag); - } - } self.specialize_at(instr_idx, cache_base, new_op); return; } // type/str/tuple(x) and class-call specializations - if callable.class().is(vm.ctx.types.type_type) - && let Some(cls) = callable.downcast_ref::() - { - if !self_or_null_is_some && nargs == 1 { - let new_op = if callable.is(&vm.ctx.types.type_type.as_object()) { - Some(Instruction::CallType1) - } else if callable.is(&vm.ctx.types.str_type.as_object()) { - Some(Instruction::CallStr1) - } else if callable.is(&vm.ctx.types.tuple_type.as_object()) { - Some(Instruction::CallTuple1) - } else { - None - }; - if let Some(new_op) = new_op { - self.specialize_at(instr_idx, cache_base, new_op); + if let Some(cls) = callable.downcast_ref::() { + if cls.slots.flags.has_feature(PyTypeFlags::IMMUTABLETYPE) { + if !self_or_null_is_some && nargs == 1 { + let new_op = if callable.is(&vm.ctx.types.type_type.as_object()) { + Some(Instruction::CallType1) + } else if callable.is(&vm.ctx.types.str_type.as_object()) { + Some(Instruction::CallStr1) + } else if 
callable.is(&vm.ctx.types.tuple_type.as_object()) { + Some(Instruction::CallTuple1) + } else { + None + }; + if let Some(new_op) = new_op { + self.specialize_at(instr_idx, cache_base, new_op); + return; + } + } + if cls.slots.vectorcall.load().is_some() { + self.specialize_at(instr_idx, cache_base, Instruction::CallBuiltinClass); return; } + self.specialize_at(instr_idx, cache_base, Instruction::CallNonPyGeneral); + return; } - if cls.slots.flags.has_feature(PyTypeFlags::IMMUTABLETYPE) - && cls.slots.vectorcall.load().is_some() - { - self.specialize_at(instr_idx, cache_base, Instruction::CallBuiltinClass); + + // CPython only considers CALL_ALLOC_AND_ENTER_INIT for types whose + // metaclass is exactly `type`. + if !callable.class().is(vm.ctx.types.type_type) { + self.specialize_at(instr_idx, cache_base, Instruction::CallNonPyGeneral); return; } + // CallAllocAndEnterInit: heap type with default __new__ if !self_or_null_is_some && cls.slots.flags.has_feature(PyTypeFlags::HEAPTYPE) { let object_new = vm.ctx.types.object_type.slots.new.load(); let cls_new = cls.slots.new.load(); - if let (Some(cls_new_fn), Some(obj_new_fn)) = (cls_new, object_new) + let object_alloc = vm.ctx.types.object_type.slots.alloc.load(); + let cls_alloc = cls.slots.alloc.load(); + if let (Some(cls_new_fn), Some(obj_new_fn), Some(cls_alloc_fn), Some(obj_alloc_fn)) = + (cls_new, object_new, cls_alloc, object_alloc) && cls_new_fn as usize == obj_new_fn as usize - && let Some(init) = cls.get_attr(identifier!(vm, __init__)) - && let Some(init_func) = init.downcast_ref::() - && init_func.can_specialize_call(nargs + 1) + && cls_alloc_fn as usize == obj_alloc_fn as usize { - let version = cls.tp_version_tag.load(Acquire); - if version != 0 { + let init = cls.get_attr(identifier!(vm, __init__)); + let mut version = cls.tp_version_tag.load(Acquire); + if version == 0 { + version = cls.assign_version_tag(); + } + if version == 0 { + unsafe { + self.code.instructions.write_adaptive_counter( + cache_base, 
+ bytecode::adaptive_counter_backoff( + self.code.instructions.read_adaptive_counter(cache_base), + ), + ); + } + return; + } + if let Some(init) = init + && let Some(init_func) = init.downcast_ref_if_exact::(vm) + && init_func.is_simple_for_call_specialization() + && cls.cache_init_for_specialization(init_func.to_owned(), version, vm) + { unsafe { self.code .instructions @@ -7758,7 +8407,7 @@ impl ExecutingFrame<'_> { fn specialize_call_kw( &mut self, - _vm: &VirtualMachine, + vm: &VirtualMachine, nargs: u32, instr_idx: usize, cache_base: usize, @@ -7778,7 +8427,29 @@ impl ExecutingFrame<'_> { .is_some(); let callable = self.nth_value(nargs + 2); - if let Some(func) = callable.downcast_ref::() { + if let Some(func) = callable.downcast_ref_if_exact::(vm) { + if self.specialization_eval_frame_active(vm) { + unsafe { + self.code.instructions.write_adaptive_counter( + cache_base, + bytecode::adaptive_counter_backoff( + self.code.instructions.read_adaptive_counter(cache_base), + ), + ); + } + return; + } + if !func.is_optimized_for_call_specialization() { + unsafe { + self.code.instructions.write_adaptive_counter( + cache_base, + bytecode::adaptive_counter_backoff( + self.code.instructions.read_adaptive_counter(cache_base), + ), + ); + } + return; + } let version = func.get_version_for_current_state(); if version == 0 { unsafe { @@ -7802,11 +8473,55 @@ impl ExecutingFrame<'_> { } if !self_or_null_is_some - && let Some(bound_method) = callable.downcast_ref::() - && let Some(func) = bound_method.function_obj().downcast_ref::() + && let Some(bound_method) = callable.downcast_ref_if_exact::(vm) { - let version = func.get_version_for_current_state(); - if version == 0 { + if let Some(func) = bound_method + .function_obj() + .downcast_ref_if_exact::(vm) + { + if self.specialization_eval_frame_active(vm) { + unsafe { + self.code.instructions.write_adaptive_counter( + cache_base, + bytecode::adaptive_counter_backoff( + self.code.instructions.read_adaptive_counter(cache_base), 
+ ), + ); + } + return; + } + if !func.is_optimized_for_call_specialization() { + unsafe { + self.code.instructions.write_adaptive_counter( + cache_base, + bytecode::adaptive_counter_backoff( + self.code.instructions.read_adaptive_counter(cache_base), + ), + ); + } + return; + } + let version = func.get_version_for_current_state(); + if version == 0 { + unsafe { + self.code.instructions.write_adaptive_counter( + cache_base, + bytecode::adaptive_counter_backoff( + self.code.instructions.read_adaptive_counter(cache_base), + ), + ); + } + return; + } + unsafe { + self.code + .instructions + .write_cache_u32(cache_base + 1, version); + } + self.specialize_at(instr_idx, cache_base, Instruction::CallKwBoundMethod); + } else { + // Match CPython: bound methods wrapping non-Python callables + // are not specialized as CALL_KW_NON_PY. unsafe { self.code.instructions.write_adaptive_counter( cache_base, @@ -7815,14 +8530,7 @@ impl ExecutingFrame<'_> { ), ); } - return; } - unsafe { - self.code - .instructions - .write_cache_u32(cache_base + 1, version); - } - self.specialize_at(instr_idx, cache_base, Instruction::CallKwBoundMethod); return; } @@ -7830,7 +8538,7 @@ impl ExecutingFrame<'_> { self.specialize_at(instr_idx, cache_base, Instruction::CallKwNonPy); } - fn specialize_send(&mut self, instr_idx: usize, cache_base: usize) { + fn specialize_send(&mut self, vm: &VirtualMachine, instr_idx: usize, cache_base: usize) { if !matches!( self.code.instructions.read_op(instr_idx), Instruction::Send { .. 
} @@ -7839,7 +8547,9 @@ impl ExecutingFrame<'_> { } // Stack: [receiver, val] — receiver is at position 1 let receiver = self.nth_value(1); - if self.builtin_coro(receiver).is_some() { + let is_exact_gen_or_coro = receiver.downcast_ref_if_exact::(vm).is_some() + || receiver.downcast_ref_if_exact::(vm).is_some(); + if is_exact_gen_or_coro && !self.specialization_eval_frame_active(vm) { self.specialize_at(instr_idx, cache_base, Instruction::SendGen); } else { unsafe { @@ -7895,7 +8605,7 @@ impl ExecutingFrame<'_> { fn specialize_compare_op( &mut self, vm: &VirtualMachine, - _op: bytecode::ComparisonOperator, + op: bytecode::ComparisonOperator, instr_idx: usize, cache_base: usize, ) { @@ -7908,16 +8618,25 @@ impl ExecutingFrame<'_> { let b = self.top_value(); let a = self.nth_value(1); - let new_op = if a.downcast_ref_if_exact::(vm).is_some() - && b.downcast_ref_if_exact::(vm).is_some() - { - Some(Instruction::CompareOpInt) + let new_op = if let (Some(a_int), Some(b_int)) = ( + a.downcast_ref_if_exact::(vm), + b.downcast_ref_if_exact::(vm), + ) { + if specialization_compact_int_value(a_int, vm).is_some() + && specialization_compact_int_value(b_int, vm).is_some() + { + Some(Instruction::CompareOpInt) + } else { + None + } } else if a.downcast_ref_if_exact::(vm).is_some() && b.downcast_ref_if_exact::(vm).is_some() { Some(Instruction::CompareOpFloat) } else if a.downcast_ref_if_exact::(vm).is_some() && b.downcast_ref_if_exact::(vm).is_some() + && (op == bytecode::ComparisonOperator::Equal + || op == bytecode::ComparisonOperator::NotEqual) { Some(Instruction::CompareOpStr) } else { @@ -7935,6 +8654,12 @@ impl ExecutingFrame<'_> { .into() } + /// Recover the BinaryOperator from the instruction arg byte. + /// `replace_op` preserves the arg byte, so the original op remains accessible. 
+ fn binary_op_from_arg(&self, arg: bytecode::OpArg) -> bytecode::BinaryOperator { + bytecode::BinaryOperator::try_from(u32::from(arg)).unwrap_or(bytecode::BinaryOperator::Add) + } + fn specialize_to_bool(&mut self, vm: &VirtualMachine, instr_idx: usize, cache_base: usize) { if !matches!( self.code.instructions.read_op(instr_idx), @@ -7955,7 +8680,8 @@ impl ExecutingFrame<'_> { Some(Instruction::ToBoolList) } else if cls.is(PyStr::class(&vm.ctx)) { Some(Instruction::ToBoolStr) - } else if cls.slots.as_number.boolean.load().is_none() + } else if cls.slots.flags.has_feature(PyTypeFlags::HEAPTYPE) + && cls.slots.as_number.boolean.load().is_none() && cls.slots.as_mapping.length.load().is_none() && cls.slots.as_sequence.length.load().is_none() { @@ -7989,7 +8715,13 @@ impl ExecutingFrame<'_> { self.commit_specialization(instr_idx, cache_base, new_op); } - fn specialize_for_iter(&mut self, vm: &VirtualMachine, instr_idx: usize, cache_base: usize) { + fn specialize_for_iter( + &mut self, + vm: &VirtualMachine, + jump_delta: u32, + instr_idx: usize, + cache_base: usize, + ) { if !matches!( self.code.instructions.read_op(instr_idx), Instruction::ForIter { .. 
} @@ -8004,7 +8736,11 @@ impl ExecutingFrame<'_> { Some(Instruction::ForIterList) } else if iter.downcast_ref_if_exact::(vm).is_some() { Some(Instruction::ForIterTuple) - } else if iter.downcast_ref_if_exact::(vm).is_some() { + } else if iter.downcast_ref_if_exact::(vm).is_some() + && jump_delta <= i16::MAX as u32 + && self.for_iter_has_end_for_shape(instr_idx, jump_delta) + && !self.specialization_eval_frame_active(vm) + { Some(Instruction::ForIterGen) } else { None @@ -8013,6 +8749,69 @@ impl ExecutingFrame<'_> { self.commit_specialization(instr_idx, cache_base, new_op); } + #[inline] + fn specialization_eval_frame_active(&self, vm: &VirtualMachine) -> bool { + vm.use_tracing.get() + } + + #[inline] + fn specialization_has_datastack_space_for_func( + &self, + vm: &VirtualMachine, + func: &Py, + ) -> bool { + self.specialization_has_datastack_space_for_func_with_extra(vm, func, 0) + } + + #[inline] + fn specialization_has_datastack_space_for_func_with_extra( + &self, + vm: &VirtualMachine, + func: &Py, + extra_bytes: usize, + ) -> bool { + match func.datastack_frame_size_bytes() { + Some(frame_size) => frame_size + .checked_add(extra_bytes) + .is_some_and(|size| vm.datastack_has_space(size)), + None => extra_bytes == 0 || vm.datastack_has_space(extra_bytes), + } + } + + #[inline] + fn specialization_call_recursion_guard(&self, vm: &VirtualMachine) -> bool { + self.specialization_call_recursion_guard_with_extra_frames(vm, 0) + } + + #[inline] + fn specialization_call_recursion_guard_with_extra_frames( + &self, + vm: &VirtualMachine, + extra_frames: usize, + ) -> bool { + vm.current_recursion_depth() + .saturating_add(1) + .saturating_add(extra_frames) + >= vm.recursion_limit.get() + } + + #[inline] + fn for_iter_has_end_for_shape(&self, instr_idx: usize, jump_delta: u32) -> bool { + let target_idx = instr_idx + + 1 + + Instruction::ForIter { + delta: Arg::marker(), + } + .cache_entries() + + jump_delta as usize; + 
self.code.instructions.get(target_idx).is_some_and(|unit| { + matches!( + unit.op, + Instruction::EndFor | Instruction::InstrumentedEndFor + ) + }) + } + /// Handle iterator exhaustion in specialized FOR_ITER handlers. /// Skips END_FOR if present at target and jumps. fn for_iter_jump_on_exhausted(&mut self, target: bytecode::Label) { @@ -8117,10 +8916,16 @@ impl ExecutingFrame<'_> { let obj = self.nth_value(1); let idx = self.top_value(); - let new_op = if obj.downcast_ref_if_exact::(vm).is_some() - && idx.downcast_ref_if_exact::(vm).is_some() - { - Some(Instruction::StoreSubscrListInt) + let new_op = if let (Some(list), Some(int_idx)) = ( + obj.downcast_ref_if_exact::(vm), + idx.downcast_ref_if_exact::(vm), + ) { + let list_len = list.borrow_vec().len(); + if specialization_nonnegative_compact_index(int_idx, vm).is_some_and(|i| i < list_len) { + Some(Instruction::StoreSubscrListInt) + } else { + None + } } else if obj.downcast_ref_if_exact::(vm).is_some() { Some(Instruction::StoreSubscrDict) } else { @@ -8154,6 +8959,7 @@ impl ExecutingFrame<'_> { fn specialize_unpack_sequence( &mut self, vm: &VirtualMachine, + expected_count: u32, instr_idx: usize, cache_base: usize, ) { @@ -8165,13 +8971,19 @@ impl ExecutingFrame<'_> { } let obj = self.top_value(); let new_op = if let Some(tuple) = obj.downcast_ref_if_exact::(vm) { - if tuple.len() == 2 { + if tuple.len() != expected_count as usize { + None + } else if expected_count == 2 { Some(Instruction::UnpackSequenceTwoTuple) } else { Some(Instruction::UnpackSequenceTuple) } - } else if obj.downcast_ref_if_exact::(vm).is_some() { - Some(Instruction::UnpackSequenceList) + } else if let Some(list) = obj.downcast_ref_if_exact::(vm) { + if list.borrow_vec().len() == expected_count as usize { + Some(Instruction::UnpackSequenceList) + } else { + None + } } else { None }; @@ -8458,19 +9270,19 @@ impl ExecutingFrame<'_> { } bytecode::IntrinsicFunction1::TypeVar => { let type_var: PyObjectRef = - typing::TypeVar::new(vm, 
arg.clone(), vm.ctx.none(), vm.ctx.none()) + _typing::TypeVar::new(vm, arg.clone(), vm.ctx.none(), vm.ctx.none()) .into_ref(&vm.ctx) .into(); Ok(type_var) } bytecode::IntrinsicFunction1::ParamSpec => { - let param_spec: PyObjectRef = typing::ParamSpec::new(arg.clone(), vm) + let param_spec: PyObjectRef = _typing::ParamSpec::new(arg.clone(), vm) .into_ref(&vm.ctx) .into(); Ok(param_spec) } bytecode::IntrinsicFunction1::TypeVarTuple => { - let type_var_tuple: PyObjectRef = typing::TypeVarTuple::new(arg.clone(), vm) + let type_var_tuple: PyObjectRef = _typing::TypeVarTuple::new(arg.clone(), vm) .into_ref(&vm.ctx) .into(); Ok(type_var_tuple) @@ -8500,10 +9312,10 @@ impl ExecutingFrame<'_> { .map_err(|_| vm.new_type_error("Type params must be a tuple."))? }; - let name = name.downcast::().map_err(|_| { - vm.new_type_error("TypeAliasType name must be a string".to_owned()) - })?; - let type_alias = typing::TypeAliasType::new(name, type_params, compute_value); + let name = name + .downcast::() + .map_err(|_| vm.new_type_error("TypeAliasType name must be a string"))?; + let type_alias = _typing::TypeAliasType::new(name, type_params, compute_value); Ok(type_alias.into_ref(&vm.ctx).into()) } bytecode::IntrinsicFunction1::ListToTuple => { @@ -8549,7 +9361,7 @@ impl ExecutingFrame<'_> { ) -> PyResult { match func { bytecode::IntrinsicFunction2::SetTypeparamDefault => { - crate::stdlib::typing::set_typeparam_default(arg1, arg2, vm) + crate::stdlib::_typing::set_typeparam_default(arg1, arg2, vm) } bytecode::IntrinsicFunction2::SetFunctionTypeParams => { // arg1 is the function, arg2 is the type params tuple @@ -8559,14 +9371,14 @@ impl ExecutingFrame<'_> { } bytecode::IntrinsicFunction2::TypeVarWithBound => { let type_var: PyObjectRef = - typing::TypeVar::new(vm, arg1.clone(), arg2, vm.ctx.none()) + _typing::TypeVar::new(vm, arg1.clone(), arg2, vm.ctx.none()) .into_ref(&vm.ctx) .into(); Ok(type_var) } bytecode::IntrinsicFunction2::TypeVarWithConstraint => { let type_var: 
PyObjectRef = - typing::TypeVar::new(vm, arg1.clone(), vm.ctx.none(), arg2) + _typing::TypeVar::new(vm, arg1.clone(), vm.ctx.none(), arg2) .into_ref(&vm.ctx) .into(); Ok(type_var) @@ -8644,24 +9456,26 @@ impl fmt::Debug for Frame { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { // SAFETY: Debug is best-effort; concurrent mutation is unlikely // and would only affect debug output. - let localsplus = unsafe { &*self.localsplus.get() }; - let stack_str = localsplus - .stack_as_slice() - .iter() - .fold(String::new(), |mut s, slot| { - match slot { - Some(elem) if elem.downcastable::() => { - s.push_str("\n > {frame}"); - } - Some(elem) => { - core::fmt::write(&mut s, format_args!("\n > {elem:?}")).unwrap(); - } - None => { - s.push_str("\n > NULL"); + let iframe = unsafe { &*self.iframe.get() }; + let stack_str = + iframe + .localsplus + .stack_as_slice() + .iter() + .fold(String::new(), |mut s, slot| { + match slot { + Some(elem) if elem.downcastable::() => { + s.push_str("\n > {frame}"); + } + Some(elem) => { + core::fmt::write(&mut s, format_args!("\n > {elem:?}")).unwrap(); + } + None => { + s.push_str("\n > NULL"); + } } - } - s - }); + s + }); // TODO: fix this up write!( f, diff --git a/crates/vm/src/function/method.rs b/crates/vm/src/function/method.rs index 52624cbbf86..295e4d89adf 100644 --- a/crates/vm/src/function/method.rs +++ b/crates/vm/src/function/method.rs @@ -12,11 +12,11 @@ bitflags::bitflags! { // METH_XXX flags in CPython #[derive(Copy, Clone, Debug, PartialEq)] pub struct PyMethodFlags: u32 { - // const VARARGS = 0x0001; - // const KEYWORDS = 0x0002; + const VARARGS = 0x0001; + const KEYWORDS = 0x0002; // METH_NOARGS and METH_O must not be combined with the flags above. - // const NOARGS = 0x0004; - // const O = 0x0008; + const NOARGS = 0x0004; + const O = 0x0008; // METH_CLASS and METH_STATIC are a little different; these control // the construction of methods for a class. 
These cannot be used for @@ -31,7 +31,7 @@ bitflags::bitflags! { // const COEXIST = 0x0040; // if not Py_LIMITED_API - // const FASTCALL = 0x0080; + const FASTCALL = 0x0080; // This bit is preserved for Stackless Python // const STACKLESS = 0x0100; @@ -123,6 +123,7 @@ impl PyMethodDef { zelf: None, value: self, module: None, + _method_def_owner: None, } } @@ -144,6 +145,7 @@ impl PyMethodDef { zelf: Some(obj), value: self, module: None, + _method_def_owner: None, }, class, } @@ -162,6 +164,7 @@ impl PyMethodDef { zelf: Some(obj), value: self, module: None, + _method_def_owner: None, }; PyRef::new_ref( function, @@ -211,12 +214,13 @@ impl PyMethodDef { class: &'static Py, ) -> PyRef { debug_assert!(self.flags.contains(PyMethodFlags::STATIC)); - // Set zelf to the class, matching CPython's m_self = type for static methods. + // Set zelf to the class (m_self = type for static methods). // Callable::call skips prepending when STATIC flag is set. let func = PyNativeFunction { zelf: Some(class.to_owned().into()), value: self, module: None, + _method_def_owner: None, }; PyNativeMethod { func, class }.into_ref(ctx) } @@ -293,14 +297,12 @@ impl Py { } pub fn build_function(&self, vm: &VirtualMachine) -> PyRef { - let function = unsafe { self.method() }.to_function(); - let dict = vm.ctx.new_dict(); - dict.set_item("__method_def__", self.to_owned().into(), vm) - .unwrap(); + let mut function = unsafe { self.method() }.to_function(); + function._method_def_owner = Some(self.to_owned().into()); PyRef::new_ref( function, vm.ctx.types.builtin_function_or_method_type.to_owned(), - Some(dict), + None, ) } @@ -309,14 +311,12 @@ impl Py { class: &'static Py, vm: &VirtualMachine, ) -> PyRef { - let function = unsafe { self.method() }.to_method(class, &vm.ctx); - let dict = vm.ctx.new_dict(); - dict.set_item("__method_def__", self.to_owned().into(), vm) - .unwrap(); + let mut function = unsafe { self.method() }.to_method(class, &vm.ctx); + function._method_def_owner = 
Some(self.to_owned().into()); PyRef::new_ref( function, vm.ctx.types.method_descriptor_type.to_owned(), - Some(dict), + None, ) } } diff --git a/crates/vm/src/function/mod.rs b/crates/vm/src/function/mod.rs index 15048919593..4be94e3f0be 100644 --- a/crates/vm/src/function/mod.rs +++ b/crates/vm/src/function/mod.rs @@ -8,6 +8,7 @@ mod getset; mod method; mod number; mod protocol; +mod time; pub use argument::{ ArgumentError, FromArgOptional, FromArgs, FuncArgs, IntoFuncArgs, KwArgs, OptionalArg, @@ -23,6 +24,7 @@ pub(super) use getset::{IntoPyGetterFunc, IntoPySetterFunc, PyGetterFunc, PySett pub use method::{HeapMethodDef, PyMethodDef, PyMethodFlags}; pub use number::{ArgIndex, ArgIntoBool, ArgIntoComplex, ArgIntoFloat, ArgPrimitiveIndex, ArgSize}; pub use protocol::{ArgCallable, ArgIterable, ArgMapping, ArgSequence}; +pub use time::TimeoutSeconds; use crate::{PyObject, PyResult, VirtualMachine, builtins::PyStr, convert::TryFromBorrowedObject}; use builtin::{BorrowedParam, OwnedParam, RefParam}; diff --git a/crates/vm/src/function/time.rs b/crates/vm/src/function/time.rs new file mode 100644 index 00000000000..29f14495d14 --- /dev/null +++ b/crates/vm/src/function/time.rs @@ -0,0 +1,34 @@ +use crate::{PyObjectRef, PyResult, TryFromObject, VirtualMachine}; + +/// A Python timeout value that accepts both `float` and `int`. +/// +/// `TimeoutSeconds` implements `FromArgs` so that a built-in function can accept +/// timeout parameters given as either `float` or `int`, normalizing them to `f64`. +#[derive(Debug, Clone, Copy, PartialEq)] +pub struct TimeoutSeconds { + value: f64, +} + +impl TimeoutSeconds { + pub const fn new(secs: f64) -> Self { + Self { value: secs } + } + + #[inline] + pub fn to_secs_f64(self) -> f64 { + self.value + } +} + +impl TryFromObject for TimeoutSeconds { + fn try_from_object(vm: &VirtualMachine, obj: PyObjectRef) -> PyResult { + let value = match super::Either::::try_from_object(vm, obj)? 
{ + super::Either::A(f) => f, + super::Either::B(i) => i as f64, + }; + if value.is_nan() { + return Err(vm.new_value_error("Invalid value NaN (not a number)".to_owned())); + } + Ok(Self { value }) + } +} diff --git a/crates/vm/src/gc_state.rs b/crates/vm/src/gc_state.rs index e8e83bba49c..d86c3d4d560 100644 --- a/crates/vm/src/gc_state.rs +++ b/crates/vm/src/gc_state.rs @@ -1,7 +1,6 @@ //! Garbage Collection State and Algorithm //! -//! This module implements CPython-compatible generational garbage collection -//! for RustPython, using an intrusive doubly-linked list approach. +//! Generational garbage collection using an intrusive doubly-linked list. use crate::common::linked_list::LinkedList; use crate::common::lock::{PyMutex, PyRwLock}; @@ -11,6 +10,16 @@ use core::ptr::NonNull; use core::sync::atomic::{AtomicBool, AtomicU32, AtomicUsize, Ordering}; use std::collections::HashSet; +#[cfg(not(target_arch = "wasm32"))] +fn elapsed_secs(start: &std::time::Instant) -> f64 { + start.elapsed().as_secs_f64() +} + +#[cfg(target_arch = "wasm32")] +fn elapsed_secs(_start: &()) -> f64 { + 0.0 +} + bitflags::bitflags! { /// GC debug flags (see Include/internal/pycore_gc.h) #[derive(Copy, Clone, Debug, Default, PartialEq, Eq)] @@ -28,12 +37,23 @@ bitflags::bitflags! 
{ } } +/// Result from a single collection run +#[derive(Debug, Default)] +pub struct CollectResult { + pub collected: usize, + pub uncollectable: usize, + pub candidates: usize, + pub duration: f64, +} + /// Statistics for a single generation (gc_generation_stats) #[derive(Debug, Default)] pub struct GcStats { pub collections: usize, pub collected: usize, pub uncollectable: usize, + pub candidates: usize, + pub duration: f64, } /// A single GC generation with intrusive linked list @@ -55,6 +75,8 @@ impl GcGeneration { collections: 0, collected: 0, uncollectable: 0, + candidates: 0, + duration: 0.0, }), } } @@ -77,14 +99,24 @@ impl GcGeneration { collections: guard.collections, collected: guard.collected, uncollectable: guard.uncollectable, + candidates: guard.candidates, + duration: guard.duration, } } - pub fn update_stats(&self, collected: usize, uncollectable: usize) { + pub fn update_stats( + &self, + collected: usize, + uncollectable: usize, + candidates: usize, + duration: f64, + ) { let mut guard = self.stats.lock(); guard.collections += 1; guard.collected += collected; guard.uncollectable += uncollectable; + guard.candidates += candidates; + guard.duration += duration; } /// Reset the stats mutex to unlocked state after fork(). 
@@ -340,25 +372,30 @@ impl GcState { } /// Perform garbage collection on the given generation - pub fn collect(&self, generation: usize) -> (usize, usize) { + pub fn collect(&self, generation: usize) -> CollectResult { self.collect_inner(generation, false) } /// Force collection even if GC is disabled (for manual gc.collect() calls) - pub fn collect_force(&self, generation: usize) -> (usize, usize) { + pub fn collect_force(&self, generation: usize) -> CollectResult { self.collect_inner(generation, true) } - fn collect_inner(&self, generation: usize, force: bool) -> (usize, usize) { + fn collect_inner(&self, generation: usize, force: bool) -> CollectResult { if !force && !self.is_enabled() { - return (0, 0); + return CollectResult::default(); } // Try to acquire the collecting lock let Some(_guard) = self.collecting.try_lock() else { - return (0, 0); + return CollectResult::default(); }; + #[cfg(not(target_arch = "wasm32"))] + let start_time = std::time::Instant::now(); + #[cfg(target_arch = "wasm32")] + let start_time = (); + // Memory barrier to ensure visibility of all reference count updates // from other threads before we start analyzing the object graph. core::sync::atomic::fence(Ordering::SeqCst); @@ -386,11 +423,24 @@ impl GcState { } if collecting.is_empty() { - self.generations[0].count.store(0, Ordering::SeqCst); - self.generations[generation].update_stats(0, 0); - return (0, 0); + // Reset counts for generations whose objects were promoted away. + // For gen2 (oldest), survivors stay in-place so don't reset gen2 count. 
+ let reset_end = if generation >= 2 { 2 } else { generation + 1 }; + for i in 0..reset_end { + self.generations[i].count.store(0, Ordering::SeqCst); + } + let duration = elapsed_secs(&start_time); + self.generations[generation].update_stats(0, 0, 0, duration); + return CollectResult { + collected: 0, + uncollectable: 0, + candidates: 0, + duration, + }; } + let candidates = collecting.len(); + if debug.contains(GcDebugFlags::STATS) { eprintln!( "gc: collecting {} objects from generations 0..={}", @@ -486,9 +536,18 @@ impl GcState { if unreachable.is_empty() { drop(gen_locks); self.promote_survivors(generation, &survivor_refs); - self.generations[0].count.store(0, Ordering::SeqCst); - self.generations[generation].update_stats(0, 0); - return (0, 0); + let reset_end = if generation >= 2 { 2 } else { generation + 1 }; + for i in 0..reset_end { + self.generations[i].count.store(0, Ordering::SeqCst); + } + let duration = elapsed_secs(&start_time); + self.generations[generation].update_stats(0, 0, candidates, duration); + return CollectResult { + collected: 0, + uncollectable: 0, + candidates, + duration, + }; } // Release read locks before finalization phase. @@ -498,9 +557,18 @@ impl GcState { if unreachable_refs.is_empty() { self.promote_survivors(generation, &survivor_refs); - self.generations[0].count.store(0, Ordering::SeqCst); - self.generations[generation].update_stats(0, 0); - return (0, 0); + let reset_end = if generation >= 2 { 2 } else { generation + 1 }; + for i in 0..reset_end { + self.generations[i].count.store(0, Ordering::SeqCst); + } + let duration = elapsed_secs(&start_time); + self.generations[generation].update_stats(0, 0, candidates, duration); + return CollectResult { + collected: 0, + uncollectable: 0, + candidates, + duration, + }; } // 6b: Record initial strong counts (for resurrection detection) @@ -594,15 +662,25 @@ impl GcState { }; // Promote survivors to next generation BEFORE tp_clear. 
- // This matches CPython's order (move_legacy_finalizer_reachable → delete_garbage) - // and ensures survivor_refs are dropped before tp_clear, so reachable objects - // (e.g. LateFin) aren't kept alive beyond the deferred-drop phase. + // move_legacy_finalizer_reachable → delete_garbage order ensures + // survivor_refs are dropped before tp_clear, so reachable objects + // aren't kept alive beyond the deferred-drop phase. self.promote_survivors(generation, &survivor_refs); drop(survivor_refs); // Resurrected objects stay tracked — just drop our references drop(resurrected); + if debug.contains(GcDebugFlags::COLLECTABLE) { + for obj in &truly_dead { + eprintln!( + "gc: collectable <{} {:p}>", + obj.class().name(), + obj.as_ref() + ); + } + } + if debug.contains(GcDebugFlags::SAVEALL) { let mut garbage_guard = self.garbage.lock(); for obj_ref in truly_dead.iter() { @@ -624,12 +702,22 @@ impl GcState { }); } - // Reset gen0 count - self.generations[0].count.store(0, Ordering::SeqCst); + // Reset counts for generations whose objects were promoted away. + // For gen2 (oldest), survivors stay in-place so don't reset gen2 count. + let reset_end = if generation >= 2 { 2 } else { generation + 1 }; + for i in 0..reset_end { + self.generations[i].count.store(0, Ordering::SeqCst); + } - self.generations[generation].update_stats(collected, 0); + let duration = elapsed_secs(&start_time); + self.generations[generation].update_stats(collected, 0, candidates, duration); - (collected, 0) + CollectResult { + collected, + uncollectable: 0, + candidates, + duration, + } } /// Promote surviving objects to the next generation. 
diff --git a/crates/vm/src/import.rs b/crates/vm/src/import.rs index 4e89052e1a8..9d015c8f3b6 100644 --- a/crates/vm/src/import.rs +++ b/crates/vm/src/import.rs @@ -374,7 +374,7 @@ pub(crate) fn import_module_level( vm: &VirtualMachine, ) -> PyResult { if level < 0 { - return Err(vm.new_value_error("level must be >= 0".to_owned())); + return Err(vm.new_value_error("level must be >= 0")); } let name_str = match name.to_str() { @@ -411,14 +411,14 @@ pub(crate) fn import_module_level( let package = calc_package(Some(globals_ref), vm)?; if package.is_empty() { return Err(vm.new_import_error( - "attempted relative import with no known parent package".to_owned(), + "attempted relative import with no known parent package", vm.ctx.new_utf8_str(""), )); } resolve_name(name_str, &package, level as usize, vm)? } else { if name_str.is_empty() { - return Err(vm.new_value_error("Empty module name".to_owned())); + return Err(vm.new_value_error("Empty module name")); } name_str.to_owned() }; @@ -500,7 +500,7 @@ fn resolve_name(name: &str, package: &str, level: usize, vm: &VirtualMachine) -> let parts: Vec<&str> = package.rsplitn(level, '.').collect(); if parts.len() < level { return Err(vm.new_import_error( - "attempted relative import beyond top-level package".to_owned(), + "attempted relative import beyond top-level package", vm.ctx.new_utf8_str(name), )); } @@ -517,7 +517,7 @@ fn resolve_name(name: &str, package: &str, level: usize, vm: &VirtualMachine) -> fn calc_package(globals: Option<&PyObjectRef>, vm: &VirtualMachine) -> PyResult { let globals = globals.ok_or_else(|| { vm.new_import_error( - "attempted relative import with no known parent package".to_owned(), + "attempted relative import with no known parent package", vm.ctx.new_utf8_str(""), ) })?; @@ -531,7 +531,7 @@ fn calc_package(globals: Option<&PyObjectRef>, vm: &VirtualMachine) -> PyResult< let pkg_str: PyUtf8StrRef = pkg .clone() .downcast() - .map_err(|_| vm.new_type_error("package must be a 
string".to_owned()))?; + .map_err(|_| vm.new_type_error("package must be a string"))?; // Warn if __package__ != __spec__.parent if let Some(ref spec) = spec && !vm.is_none(spec) @@ -572,7 +572,7 @@ fn calc_package(globals: Option<&PyObjectRef>, vm: &VirtualMachine) -> PyResult< { let parent_str: PyUtf8StrRef = parent .downcast() - .map_err(|_| vm.new_type_error("package set to non-string".to_owned()))?; + .map_err(|_| vm.new_type_error("package set to non-string"))?; return Ok(parent_str.as_str().to_owned()); } @@ -592,13 +592,13 @@ fn calc_package(globals: Option<&PyObjectRef>, vm: &VirtualMachine) -> PyResult< let mod_name = globals.get_item("__name__", vm).map_err(|_| { vm.new_import_error( - "attempted relative import with no known parent package".to_owned(), + "attempted relative import with no known parent package", vm.ctx.new_utf8_str(""), ) })?; let mod_name_str: PyUtf8StrRef = mod_name .downcast() - .map_err(|_| vm.new_type_error("__name__ must be a string".to_owned()))?; + .map_err(|_| vm.new_type_error("__name__ must be a string"))?; let mut package = mod_name_str.as_str().to_owned(); // If not a package (no __path__), strip last component. // Uses rpartition('.')[0] semantics: returns empty string when no dot. diff --git a/crates/vm/src/object/core.rs b/crates/vm/src/object/core.rs index 13927952604..a8d7c09da89 100644 --- a/crates/vm/src/object/core.rs +++ b/crates/vm/src/object/core.rs @@ -188,26 +188,32 @@ pub(super) unsafe fn default_dealloc(obj: *mut PyObject) { ); } - // Extract child references before deallocation to break circular refs (tp_clear) + // Try to store in freelist for reuse BEFORE tp_clear, so that + // size-based freelists (e.g. PyTuple) can read the payload directly. + // Only exact base types (not heaptype or structseq subtypes) go into the freelist. 
+ let typ = obj_ref.class(); + let pushed = if T::HAS_FREELIST + && typ.heaptype_ext.is_none() + && core::ptr::eq(typ, T::class(crate::vm::Context::genesis())) + { + unsafe { T::freelist_push(obj) } + } else { + false + }; + + // Extract child references to break circular refs (tp_clear). + // This runs regardless of freelist push — the object's children must be released. let mut edges = Vec::new(); if let Some(clear_fn) = vtable.clear { unsafe { clear_fn(obj, &mut edges) }; } - // Try to store in freelist for reuse; otherwise deallocate. - // Only exact types (not heaptype subclasses) go into the freelist, - // because the pop site assumes the cached typ matches the base type. - let pushed = if T::HAS_FREELIST && obj_ref.class().heaptype_ext.is_none() { - unsafe { T::freelist_push(obj) } - } else { - false - }; if !pushed { - drop(unsafe { Box::from_raw(obj as *mut PyInner) }); + // Deallocate the object memory (handles ObjExt prefix if present) + unsafe { PyInner::dealloc(obj as *mut PyInner) }; } // Drop child references - may trigger recursive destruction. - // The object is already deallocated, so circular refs are broken. drop(edges); // Trashcan: decrement depth and process deferred objects at outermost level @@ -286,6 +292,52 @@ unsafe impl Link for GcLink { } } +/// Extension fields for objects that need dict or member slots. +/// Allocated as a prefix before PyInner when needed (prefix allocation pattern). +/// Access via `PyInner::ext_ref()` using negative offset from the object pointer. +/// +/// align(8) ensures size_of::() is always a multiple of 8, +/// so the offset from Layout::extend equals size_of::() for any +/// PyInner alignment (important on wasm32 where pointers are 4 bytes +/// but some payloads like PyWeak have align 8 due to i64 fields). 
+#[repr(C, align(8))] +pub(super) struct ObjExt { + pub(super) dict: Option, + pub(super) slots: Box<[PyRwLock>]>, +} + +impl ObjExt { + fn new(dict: Option, member_count: usize) -> Self { + Self { + dict: dict.map(InstanceDict::new), + slots: core::iter::repeat_with(|| PyRwLock::new(None)) + .take(member_count) + .collect_vec() + .into_boxed_slice(), + } + } +} + +impl fmt::Debug for ObjExt { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "[ObjExt]") + } +} + +/// Precomputed offset constants for prefix allocation. +/// All prefix components are align(8) and their sizes are multiples of 8, +/// so Layout::extend adds no inter-padding. +const EXT_OFFSET: usize = core::mem::size_of::(); +const WEAKREF_OFFSET: usize = core::mem::size_of::(); + +const _: () = + assert!(core::mem::size_of::().is_multiple_of(core::mem::align_of::())); +const _: () = assert!(core::mem::align_of::() >= core::mem::align_of::>()); +const _: () = assert!( + core::mem::size_of::().is_multiple_of(core::mem::align_of::()) +); +const _: () = assert!(core::mem::align_of::() >= core::mem::align_of::>()); + /// This is an actual python object. It consists of a `typ` which is the /// python class, and carries some rust payload optionally. This rust /// payload can be a rust float or rust int in case of float and int objects. @@ -302,14 +354,65 @@ pub(super) struct PyInner { pub(super) gc_pointers: Pointers, pub(super) typ: PyAtomicRef, // __class__ member - pub(super) dict: Option, - pub(super) weak_list: WeakRefList, - pub(super) slots: Box<[PyRwLock>]>, pub(super) payload: T, } pub(crate) const SIZEOF_PYOBJECT_HEAD: usize = core::mem::size_of::>(); +impl PyInner { + /// Read type flags and member_count via raw pointers to avoid Stacked Borrows + /// violations during bootstrap, where type objects have self-referential typ pointers. 
+ #[inline(always)] + fn read_type_flags(&self) -> (crate::types::PyTypeFlags, usize) { + let typ_ptr = self.typ.load_raw(); + let slots = unsafe { core::ptr::addr_of!((*typ_ptr).0.payload.slots) }; + let flags = unsafe { core::ptr::addr_of!((*slots).flags).read() }; + let member_count = unsafe { core::ptr::addr_of!((*slots).member_count).read() }; + (flags, member_count) + } + + /// Access the ObjExt prefix at a negative offset from this PyInner. + /// Returns None if this object was allocated without dict/slots. + /// + /// Layout: [ObjExt?][WeakRefList?][PyInner] + /// ObjExt offset depends on whether WeakRefList is also present. + #[inline(always)] + pub(super) fn ext_ref(&self) -> Option<&ObjExt> { + let (flags, member_count) = self.read_type_flags(); + let has_ext = flags.has_feature(crate::types::PyTypeFlags::HAS_DICT) || member_count > 0; + if !has_ext { + return None; + } + let has_weakref = flags.has_feature(crate::types::PyTypeFlags::HAS_WEAKREF); + let offset = if has_weakref { + WEAKREF_OFFSET + EXT_OFFSET + } else { + EXT_OFFSET + }; + let self_addr = (self as *const Self as *const u8).addr(); + let ext_ptr = core::ptr::with_exposed_provenance::(self_addr.wrapping_sub(offset)); + Some(unsafe { &*ext_ptr }) + } + + /// Access the WeakRefList prefix at a fixed negative offset from this PyInner. + /// Returns None if the type does not support weakrefs. + /// + /// Layout: [ObjExt?][WeakRefList?][PyInner] + /// WeakRefList is always immediately before PyInner (fixed WEAKREF_OFFSET). 
+ #[inline(always)] + pub(super) fn weakref_list_ref(&self) -> Option<&WeakRefList> { + let (flags, _) = self.read_type_flags(); + if !flags.has_feature(crate::types::PyTypeFlags::HAS_WEAKREF) { + return None; + } + let self_addr = (self as *const Self as *const u8).addr(); + let ptr = core::ptr::with_exposed_provenance::( + self_addr.wrapping_sub(WEAKREF_OFFSET), + ); + Some(unsafe { &*ptr }) + } +} + impl fmt::Debug for PyInner { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "[PyObject {:?}]", &self.payload) @@ -397,6 +500,7 @@ pub(crate) fn reset_weakref_locks_after_fork() { // === WeakRefList: inline on every object (tp_weaklist) === +#[repr(C)] pub(super) struct WeakRefList { /// Head of the intrusive doubly-linked list of weakrefs. head: PyAtomic<*mut Py>, @@ -764,7 +868,8 @@ impl PyWeak { } let obj = unsafe { &*obj_ptr }; - let wrl = &obj.0.weak_list; + // Safety: if a weakref exists pointing to this object, weakref prefix must be present + let wrl = obj.0.weakref_list_ref().unwrap(); // Compute our Py node pointer from payload address let offset = std::mem::offset_of!(PyInner, payload); @@ -837,29 +942,163 @@ impl InstanceDict { pub fn replace(&self, d: PyDictRef) -> PyDictRef { core::mem::replace(&mut self.d.write(), d) } + + /// Consume the InstanceDict and return the inner PyDictRef. + #[inline] + pub fn into_inner(self) -> PyDictRef { + self.d.into_inner() + } +} + +impl PyInner { + /// Deallocate a PyInner, handling optional prefix(es). + /// Layout: [ObjExt?][WeakRefList?][PyInner] + /// + /// # Safety + /// `ptr` must be a valid pointer from `PyInner::new` and must not be used after this call. 
+ unsafe fn dealloc(ptr: *mut Self) { + unsafe { + let (flags, member_count) = (*ptr).read_type_flags(); + let has_ext = + flags.has_feature(crate::types::PyTypeFlags::HAS_DICT) || member_count > 0; + let has_weakref = flags.has_feature(crate::types::PyTypeFlags::HAS_WEAKREF); + + if has_ext || has_weakref { + // Reconstruct the same layout used in new() + let mut layout = core::alloc::Layout::from_size_align(0, 1).unwrap(); + + if has_ext { + layout = layout + .extend(core::alloc::Layout::new::()) + .unwrap() + .0; + } + if has_weakref { + layout = layout + .extend(core::alloc::Layout::new::()) + .unwrap() + .0; + } + let (combined, inner_offset) = + layout.extend(core::alloc::Layout::new::()).unwrap(); + let combined = combined.pad_to_align(); + + let alloc_ptr = (ptr as *mut u8).sub(inner_offset); + + // Drop PyInner (payload, typ, etc.) + core::ptr::drop_in_place(ptr); + + // Drop ObjExt if present (dict, slots) + if has_ext { + core::ptr::drop_in_place(alloc_ptr as *mut ObjExt); + } + // WeakRefList has no Drop (just raw pointers), no drop_in_place needed + + alloc::alloc::dealloc(alloc_ptr, combined); + } else { + drop(Box::from_raw(ptr)); + } + } + } } impl PyInner { - fn new(payload: T, typ: PyTypeRef, dict: Option) -> Box { + /// Allocate a new PyInner, optionally with prefix(es). + /// Returns a raw pointer to the PyInner (NOT the allocation start). 
+ /// Layout: [ObjExt?][WeakRefList?][PyInner] + fn new(payload: T, typ: PyTypeRef, dict: Option) -> *mut Self { let member_count = typ.slots.member_count; - Box::new(Self { - ref_count: RefCount::new(), - vtable: PyObjVTable::of::(), - gc_bits: Radium::new(0), - gc_generation: Radium::new(GC_UNTRACKED), - gc_pointers: Pointers::new(), - typ: PyAtomicRef::from(typ), - dict: dict.map(InstanceDict::new), - weak_list: WeakRefList::new(), - payload, - slots: core::iter::repeat_with(|| PyRwLock::new(None)) - .take(member_count) - .collect_vec() - .into_boxed_slice(), - }) + let needs_ext = typ + .slots + .flags + .has_feature(crate::types::PyTypeFlags::HAS_DICT) + || member_count > 0; + let needs_weakref = typ + .slots + .flags + .has_feature(crate::types::PyTypeFlags::HAS_WEAKREF); + debug_assert!( + needs_ext || dict.is_none(), + "dict passed to type '{}' without HAS_DICT flag", + typ.name() + ); + + if needs_ext || needs_weakref { + // Build layout left-to-right: [ObjExt?][WeakRefList?][PyInner] + let mut layout = core::alloc::Layout::from_size_align(0, 1).unwrap(); + + let ext_start = if needs_ext { + let (combined, offset) = + layout.extend(core::alloc::Layout::new::()).unwrap(); + layout = combined; + Some(offset) + } else { + None + }; + + let weakref_start = if needs_weakref { + let (combined, offset) = layout + .extend(core::alloc::Layout::new::()) + .unwrap(); + layout = combined; + Some(offset) + } else { + None + }; + + let (combined, inner_offset) = + layout.extend(core::alloc::Layout::new::()).unwrap(); + let combined = combined.pad_to_align(); + + let alloc_ptr = unsafe { alloc::alloc::alloc(combined) }; + if alloc_ptr.is_null() { + alloc::alloc::handle_alloc_error(combined); + } + // Expose provenance so ext_ref()/weakref_list_ref() can reconstruct + alloc_ptr.expose_provenance(); + + unsafe { + if let Some(offset) = ext_start { + let ext_ptr = alloc_ptr.add(offset) as *mut ObjExt; + ext_ptr.write(ObjExt::new(dict, member_count)); + } + + if let 
Some(offset) = weakref_start { + let weakref_ptr = alloc_ptr.add(offset) as *mut WeakRefList; + weakref_ptr.write(WeakRefList::new()); + } + + let inner_ptr = alloc_ptr.add(inner_offset) as *mut Self; + inner_ptr.write(Self { + ref_count: RefCount::new(), + vtable: PyObjVTable::of::(), + gc_bits: Radium::new(0), + gc_generation: Radium::new(GC_UNTRACKED), + gc_pointers: Pointers::new(), + typ: PyAtomicRef::from(typ), + payload, + }); + inner_ptr + } + } else { + Box::into_raw(Box::new(Self { + ref_count: RefCount::new(), + vtable: PyObjVTable::of::(), + gc_bits: Radium::new(0), + gc_generation: Radium::new(GC_UNTRACKED), + gc_pointers: Pointers::new(), + typ: PyAtomicRef::from(typ), + payload, + })) + } } } +/// Returns the allocation layout for `PyInner`, for use in freelist Drop impls. +pub(crate) const fn pyinner_layout() -> core::alloc::Layout { + core::alloc::Layout::new::>() +} + /// Thread-local freelist storage for reusing object allocations. /// /// Wraps a `Vec<*mut PyObject>`. On thread teardown, `Drop` frees raw @@ -1075,9 +1314,29 @@ impl PyObjectRef { } impl PyObject { + /// Returns the WeakRefList if the type supports weakrefs (HAS_WEAKREF). + /// The WeakRefList is stored as a separate prefix before PyInner, + /// independent from ObjExt (dict/slots). #[inline(always)] - const fn weak_ref_list(&self) -> Option<&WeakRefList> { - Some(&self.0.weak_list) + fn weak_ref_list(&self) -> Option<&WeakRefList> { + self.0.weakref_list_ref() + } + + /// Returns the first weakref in the weakref list, if any. 
+ pub(crate) fn get_weakrefs(&self) -> Option { + let wrl = self.weak_ref_list()?; + let _lock = weakref_lock::lock(self as *const PyObject as usize); + let head_ptr = wrl.head.load(Ordering::Relaxed); + if head_ptr.is_null() { + None + } else { + let head = unsafe { &*head_ptr }; + if head.0.ref_count.safe_inc() { + Some(unsafe { PyRef::from_raw(head_ptr) }.into()) + } else { + None + } + } } pub(crate) fn downgrade_with_weakref_typ_opt( @@ -1096,6 +1355,18 @@ impl PyObject { typ: PyTypeRef, vm: &VirtualMachine, ) -> PyResult> { + // Check HAS_WEAKREF flag first + if !self + .class() + .slots + .flags + .has_feature(crate::types::PyTypeFlags::HAS_WEAKREF) + { + return Err(vm.new_type_error(format!( + "cannot create weak reference to '{}' object", + self.class().name() + ))); + } let dict = if typ .slots .flags @@ -1180,8 +1451,8 @@ impl PyObject { } #[inline(always)] - const fn instance_dict(&self) -> Option<&InstanceDict> { - self.0.dict.as_ref() + fn instance_dict(&self) -> Option<&InstanceDict> { + self.0.ext_ref().and_then(|ext| ext.dict.as_ref()) } #[inline(always)] @@ -1396,11 +1667,11 @@ impl PyObject { } pub(crate) fn get_slot(&self, offset: usize) -> Option { - self.0.slots[offset].read().clone() + self.0.ext_ref().unwrap().slots[offset].read().clone() } pub(crate) fn set_slot(&self, offset: usize, value: Option) { - *self.0.slots[offset].write() = value; + *self.0.ext_ref().unwrap().slots[offset].write() = value; } /// _PyObject_GC_IS_TRACKED @@ -1486,10 +1757,32 @@ impl PyObject { unsafe { clear_fn(ptr, &mut result) }; } - // 2. Clear member slots (subtype_clear) - for slot in obj.0.slots.iter() { - if let Some(val) = slot.write().take() { - result.push(val); + // 2. Clear dict and member slots (subtype_clear) + // Detach the dict via Py_CLEAR(*_PyObject_GetDictPtr(self)) — NULL + // the pointer without clearing dict contents. The dict may still be + // referenced by other live objects (e.g. function.__globals__). 
+ let (flags, member_count) = obj.0.read_type_flags(); + let has_ext = flags.has_feature(crate::types::PyTypeFlags::HAS_DICT) || member_count > 0; + if has_ext { + let has_weakref = flags.has_feature(crate::types::PyTypeFlags::HAS_WEAKREF); + let offset = if has_weakref { + WEAKREF_OFFSET + EXT_OFFSET + } else { + EXT_OFFSET + }; + let self_addr = (ptr as *const u8).addr(); + let ext_ptr = + core::ptr::with_exposed_provenance_mut::(self_addr.wrapping_sub(offset)); + let ext = unsafe { &mut *ext_ptr }; + if let Some(old_dict) = ext.dict.take() { + // Get the dict ref before dropping InstanceDict + let dict_ref = old_dict.into_inner(); + result.push(dict_ref.into()); + } + for slot in ext.slots.iter() { + if let Some(val) = slot.write().take() { + result.push(val); + } } } @@ -1513,7 +1806,11 @@ impl PyObject { /// Check if this object has clear capability (tp_clear) // Py_TPFLAGS_HAVE_GC types have tp_clear pub fn gc_has_clear(&self) -> bool { - self.0.vtable.clear.is_some() || self.0.dict.is_some() || !self.0.slots.is_empty() + self.0.vtable.clear.is_some() + || self + .0 + .ext_ref() + .is_some_and(|ext| ext.dict.is_some() || !ext.slots.is_empty()) } } @@ -1881,9 +2178,9 @@ impl PyRef { let has_dict = dict.is_some(); let is_heaptype = typ.heaptype_ext.is_some(); - // Try to reuse from freelist (exact type only, no dict, no heaptype) + // Try to reuse from freelist (no dict, no heaptype) let cached = if !has_dict && !is_heaptype { - unsafe { T::freelist_pop() } + unsafe { T::freelist_pop(&payload) } } else { None }; @@ -1895,14 +2192,19 @@ impl PyRef { (*inner).gc_bits.store(0, Ordering::Relaxed); core::ptr::drop_in_place(&mut (*inner).payload); core::ptr::write(&mut (*inner).payload, payload); - // typ, vtable, slots are preserved; dict is None, weak_list was - // cleared by drop_slow_inner before freelist push + // Freelist only stores exact base types (push-side filter), + // but subtypes sharing the same Rust payload (e.g. structseq) + // may pop entries. 
Update typ if it differs. + let cached_typ: *const Py = &*(*inner).typ; + if core::ptr::eq(cached_typ, &*typ) { + drop(typ); + } else { + let _old = (*inner).typ.swap(typ); + } } - // Drop the caller's typ since the cached object already holds one - drop(typ); unsafe { NonNull::new_unchecked(inner.cast::>()) } } else { - let inner = Box::into_raw(PyInner::new(payload, typ, dict)); + let inner = PyInner::new(payload, typ, dict); unsafe { NonNull::new_unchecked(inner.cast::>()) } }; @@ -2122,34 +2424,64 @@ pub(crate) fn init_type_hierarchy() -> (PyTypeRef, PyTypeRef, PyTypeRef) { heaptype_ext: None, tp_version_tag: core::sync::atomic::AtomicU32::new(0), }; - let type_type_ptr = Box::into_raw(Box::new(partially_init!( - PyInner:: { - ref_count: RefCount::new(), - vtable: PyObjVTable::of::(), - gc_bits: Radium::new(0), - gc_generation: Radium::new(GC_UNTRACKED), - gc_pointers: Pointers::new(), - dict: None, - weak_list: WeakRefList::new(), - payload: type_payload, - slots: Box::new([]), - }, - Uninit { typ } - ))); - let object_type_ptr = Box::into_raw(Box::new(partially_init!( - PyInner:: { - ref_count: RefCount::new(), - vtable: PyObjVTable::of::(), - gc_bits: Radium::new(0), - gc_generation: Radium::new(GC_UNTRACKED), - gc_pointers: Pointers::new(), - dict: None, - weak_list: WeakRefList::new(), - payload: object_payload, - slots: Box::new([]), - }, - Uninit { typ }, - ))); + // Both type_type and object_type are instances of `type`, which has + // HAS_DICT and HAS_WEAKREF, so they need both ObjExt and WeakRefList prefixes. 
+ // Layout: [ObjExt][WeakRefList][PyInner] + let alloc_type_with_prefixes = || -> *mut MaybeUninit> { + let inner_layout = core::alloc::Layout::new::>>(); + let ext_layout = core::alloc::Layout::new::(); + let weakref_layout = core::alloc::Layout::new::(); + + let (layout, weakref_offset) = ext_layout.extend(weakref_layout).unwrap(); + let (combined, inner_offset) = layout.extend(inner_layout).unwrap(); + let combined = combined.pad_to_align(); + + let alloc_ptr = unsafe { alloc::alloc::alloc(combined) }; + if alloc_ptr.is_null() { + alloc::alloc::handle_alloc_error(combined); + } + alloc_ptr.expose_provenance(); + + unsafe { + let ext_ptr = alloc_ptr as *mut ObjExt; + ext_ptr.write(ObjExt::new(None, 0)); + + let weakref_ptr = alloc_ptr.add(weakref_offset) as *mut WeakRefList; + weakref_ptr.write(WeakRefList::new()); + + alloc_ptr.add(inner_offset) as *mut MaybeUninit> + } + }; + + let type_type_ptr = alloc_type_with_prefixes(); + unsafe { + type_type_ptr.write(partially_init!( + PyInner:: { + ref_count: RefCount::new(), + vtable: PyObjVTable::of::(), + gc_bits: Radium::new(0), + gc_generation: Radium::new(GC_UNTRACKED), + gc_pointers: Pointers::new(), + payload: type_payload, + }, + Uninit { typ } + )); + } + + let object_type_ptr = alloc_type_with_prefixes(); + unsafe { + object_type_ptr.write(partially_init!( + PyInner:: { + ref_count: RefCount::new(), + vtable: PyObjVTable::of::(), + gc_bits: Radium::new(0), + gc_generation: Radium::new(GC_UNTRACKED), + gc_pointers: Pointers::new(), + payload: object_payload, + }, + Uninit { typ }, + )); + } let object_type_ptr = object_type_ptr as *mut PyInner; let type_type_ptr = type_type_ptr as *mut PyInner; diff --git a/crates/vm/src/object/ext.rs b/crates/vm/src/object/ext.rs index 0fd251499f1..e39d1c7765f 100644 --- a/crates/vm/src/object/ext.rs +++ b/crates/vm/src/object/ext.rs @@ -289,8 +289,12 @@ impl fmt::Debug for PyAtomicRef { impl From> for PyAtomicRef { fn from(pyref: PyRef) -> Self { let py = 
PyRef::leak(pyref); + let ptr = py as *const _ as *mut u8; + // Expose provenance so we can re-derive via with_exposed_provenance + // without Stacked Borrows tag restrictions during bootstrap + ptr.expose_provenance(); Self { - inner: Radium::new(py as *const _ as *mut _), + inner: Radium::new(ptr), _phantom: Default::default(), } } @@ -311,6 +315,14 @@ impl Deref for PyAtomicRef { } impl PyAtomicRef { + /// Load the raw pointer without creating a reference. + /// Avoids Stacked Borrows retag, safe for use during bootstrap + /// when type objects have self-referential pointers being mutated. + #[inline(always)] + pub(super) fn load_raw(&self) -> *const Py { + self.inner.load(Ordering::Relaxed).cast::>() + } + /// # Safety /// The caller is responsible to keep the returned PyRef alive /// until no more reference can be used via PyAtomicRef::deref() @@ -343,11 +355,19 @@ impl From>> for PyAtomicRef> { impl PyAtomicRef> { pub fn deref(&self) -> Option<&Py> { - unsafe { self.inner.load(Ordering::Relaxed).cast::>().as_ref() } + self.deref_ordering(Ordering::Relaxed) + } + + pub fn deref_ordering(&self, ordering: Ordering) -> Option<&Py> { + unsafe { self.inner.load(ordering).cast::>().as_ref() } } pub fn to_owned(&self) -> Option> { - self.deref().map(|x| x.to_owned()) + self.to_owned_ordering(Ordering::Relaxed) + } + + pub fn to_owned_ordering(&self, ordering: Ordering) -> Option> { + self.deref_ordering(ordering).map(|x| x.to_owned()) } /// # Safety @@ -429,16 +449,19 @@ impl From> for PyAtomicRef> { impl PyAtomicRef> { pub fn deref(&self) -> Option<&PyObject> { - unsafe { - self.inner - .load(Ordering::Relaxed) - .cast::() - .as_ref() - } + self.deref_ordering(Ordering::Relaxed) + } + + pub fn deref_ordering(&self, ordering: Ordering) -> Option<&PyObject> { + unsafe { self.inner.load(ordering).cast::().as_ref() } } pub fn to_owned(&self) -> Option { - self.deref().map(|x| x.to_owned()) + self.to_owned_ordering(Ordering::Relaxed) + } + + pub fn 
to_owned_ordering(&self, ordering: Ordering) -> Option { + self.deref_ordering(ordering).map(|x| x.to_owned()) } /// # Safety diff --git a/crates/vm/src/object/payload.rs b/crates/vm/src/object/payload.rs index 1af954505f7..a615123c680 100644 --- a/crates/vm/src/object/payload.rs +++ b/crates/vm/src/object/payload.rs @@ -57,10 +57,11 @@ pub trait PyPayload: MaybeTraverse + PyThreadingConstraint + Sized + 'static { /// Try to push a dead object onto this type's freelist for reuse. /// Returns true if the object was stored (caller must NOT free the memory). + /// Called before tp_clear, so the payload is still intact. /// /// # Safety - /// `obj` must be a valid pointer to a `PyInner` with refcount 0, - /// after `drop_slow_inner` and `tp_clear` have already run. + /// `obj` must be a valid pointer to a `PyInner` with refcount 0. + /// The payload is still initialized and can be read for bucket selection. #[inline] unsafe fn freelist_push(_obj: *mut PyObject) -> bool { false @@ -75,7 +76,7 @@ pub trait PyPayload: MaybeTraverse + PyThreadingConstraint + Sized + 'static { /// whose payload is still initialized from a previous allocation. The caller /// will drop and overwrite `payload` before reuse. 
#[inline] - unsafe fn freelist_pop() -> Option> { + unsafe fn freelist_pop(_payload: &Self) -> Option> { None } diff --git a/crates/vm/src/object/traverse_object.rs b/crates/vm/src/object/traverse_object.rs index 3f88c6b7481..f5614b3502a 100644 --- a/crates/vm/src/object/traverse_object.rs +++ b/crates/vm/src/object/traverse_object.rs @@ -65,9 +65,11 @@ unsafe impl Traverse for PyInner { let typ_obj: &PyObject = unsafe { &*(typ as *const _ as *const PyObject) }; tracer_fn(typ_obj); } - self.dict.traverse(tracer_fn); - // weak_list is inline atomic pointers, no heap allocation, no trace - self.slots.traverse(tracer_fn); + // Traverse ObjExt prefix fields (dict and slots) if present + if let Some(ext) = self.ext_ref() { + ext.dict.traverse(tracer_fn); + ext.slots.traverse(tracer_fn); + } if let Some(f) = self.vtable.trace { unsafe { @@ -87,9 +89,11 @@ unsafe impl Traverse for PyInner { let typ_obj: &PyObject = unsafe { &*(typ as *const _ as *const PyObject) }; tracer_fn(typ_obj); } - self.dict.traverse(tracer_fn); - // weak_list is inline atomic pointers, no heap allocation, no trace - self.slots.traverse(tracer_fn); + // Traverse ObjExt prefix fields (dict and slots) if present + if let Some(ext) = self.ext_ref() { + ext.dict.traverse(tracer_fn); + ext.slots.traverse(tracer_fn); + } T::try_traverse(&self.payload, tracer_fn); } } diff --git a/crates/vm/src/ospath.rs b/crates/vm/src/ospath.rs index 00195460ea3..d3123a87acb 100644 --- a/crates/vm/src/ospath.rs +++ b/crates/vm/src/ospath.rs @@ -86,7 +86,7 @@ impl PathConverter { .class() .is(crate::builtins::bool_::PyBool::static_type()) { - crate::stdlib::warnings::warn( + crate::stdlib::_warnings::warn( vm.ctx.exceptions.runtime_warning, "bool is used as a file descriptor".to_owned(), 1, diff --git a/crates/vm/src/protocol/callable.rs b/crates/vm/src/protocol/callable.rs index cecb9431fbb..6ff988abbe6 100644 --- a/crates/vm/src/protocol/callable.rs +++ b/crates/vm/src/protocol/callable.rs @@ -146,6 +146,14 @@ 
pub(crate) enum TraceEvent { } impl TraceEvent { + /// Whether sys.settrace receives this event. + fn is_trace_event(&self) -> bool { + matches!( + self, + Self::Call | Self::Return | Self::Exception | Self::Line | Self::Opcode + ) + } + /// Whether sys.setprofile receives this event. /// In legacy_tracing.c, profile callbacks are only registered for /// PY_RETURN, PY_UNWIND, C_CALL, C_RETURN, C_RAISE. @@ -211,6 +219,7 @@ impl VirtualMachine { return Ok(None); } + let is_trace_event = event.is_trace_event(); let is_profile_event = event.is_profile_event(); let is_opcode_event = event.is_opcode_event(); @@ -231,7 +240,7 @@ impl VirtualMachine { // temporarily disable tracing, during the call to the // tracing function itself. - if !self.is_none(&trace_func) { + if is_trace_event && !self.is_none(&trace_func) { self.use_tracing.set(false); let res = trace_func.call(args.clone(), self); self.use_tracing.set(true); diff --git a/crates/vm/src/protocol/number.rs b/crates/vm/src/protocol/number.rs index 542afce2c6c..36dbd5b8843 100644 --- a/crates/vm/src/protocol/number.rs +++ b/crates/vm/src/protocol/number.rs @@ -11,7 +11,7 @@ use crate::{ common::int::{BytesToIntError, bytes_to_int}, function::ArgBytesLike, object::{Traverse, TraverseFn}, - stdlib::warnings, + stdlib::_warnings, }; pub type PyNumberUnaryFunc = fn(PyNumber<'_>, &VirtualMachine) -> PyResult; @@ -59,7 +59,7 @@ impl PyObject { } else if let Some(i) = self.number().int(vm).or_else(|| self.try_index_opt(vm)) { i } else if let Ok(Some(f)) = vm.get_special_method(self, identifier!(vm, __trunc__)) { - warnings::warn( + _warnings::warn( vm.ctx.exceptions.deprecation_warning, "The delegation of int() to __trunc__ is deprecated.".to_owned(), 1, @@ -589,7 +589,7 @@ impl PyNumber<'_> { let ret_class = ret.class().to_owned(); if let Some(ret) = ret.downcast_ref::() { - warnings::warn( + _warnings::warn( vm.ctx.exceptions.deprecation_warning, format!( "__int__ returned non-int (type {ret_class}). 
\ @@ -622,7 +622,7 @@ impl PyNumber<'_> { let ret_class = ret.class().to_owned(); if let Some(ret) = ret.downcast_ref::() { - warnings::warn( + _warnings::warn( vm.ctx.exceptions.deprecation_warning, format!( "__index__ returned non-int (type {ret_class}). \ @@ -655,7 +655,7 @@ impl PyNumber<'_> { let ret_class = ret.class().to_owned(); if let Some(ret) = ret.downcast_ref::() { - warnings::warn( + _warnings::warn( vm.ctx.exceptions.deprecation_warning, format!( "__float__ returned non-float (type {ret_class}). \ diff --git a/crates/vm/src/signal.rs b/crates/vm/src/signal.rs index 87c4fe2749f..3caf8cb8e30 100644 --- a/crates/vm/src/signal.rs +++ b/crates/vm/src/signal.rs @@ -91,6 +91,11 @@ pub(crate) fn set_triggered() { ANY_TRIGGERED.store(true, Ordering::Release); } +#[inline(always)] +pub(crate) fn is_triggered() -> bool { + ANY_TRIGGERED.load(Ordering::Relaxed) +} + /// Reset all signal trigger state after fork in child process. /// Stale triggers from the parent must not fire in the child. 
#[cfg(unix)] @@ -116,7 +121,7 @@ pub fn assert_in_range(signum: i32, vm: &VirtualMachine) -> PyResult<()> { #[allow(dead_code)] #[cfg(all(not(target_arch = "wasm32"), feature = "host_env"))] pub fn set_interrupt_ex(signum: i32, vm: &VirtualMachine) -> PyResult<()> { - use crate::stdlib::signal::_signal::{SIG_DFL, SIG_IGN, run_signal}; + use crate::stdlib::_signal::_signal::{SIG_DFL, SIG_IGN, run_signal}; assert_in_range(signum, vm)?; match signum as usize { diff --git a/crates/vm/src/stdlib/_abc.rs b/crates/vm/src/stdlib/_abc.rs index 5657cda9865..28b3399ec4b 100644 --- a/crates/vm/src/stdlib/_abc.rs +++ b/crates/vm/src/stdlib/_abc.rs @@ -77,7 +77,7 @@ mod _abc { let impl_obj = cls.get_attr("_abc_impl", vm)?; impl_obj .downcast::() - .map_err(|_| vm.new_type_error("_abc_impl is set to a wrong type".to_owned())) + .map_err(|_| vm.new_type_error("_abc_impl is set to a wrong type")) } /// Check if obj is in the weak set @@ -152,12 +152,10 @@ mod _abc { while let PyIterReturn::Return(item) = iter.next(vm)? { let tuple: PyTupleRef = item .downcast() - .map_err(|_| vm.new_type_error("items() returned non-tuple".to_owned()))?; + .map_err(|_| vm.new_type_error("items() returned non-tuple"))?; let elements = tuple.as_slice(); if elements.len() != 2 { - return Err( - vm.new_type_error("items() returned item which size is not 2".to_owned()) - ); + return Err(vm.new_type_error("items() returned item which size is not 2")); } let key = &elements[0]; let value = &elements[1]; @@ -174,7 +172,7 @@ mod _abc { let bases: PyTupleRef = cls .get_attr("__bases__", vm)? 
.downcast() - .map_err(|_| vm.new_type_error("__bases__ is not a tuple".to_owned()))?; + .map_err(|_| vm.new_type_error("__bases__ is not a tuple"))?; for base in bases.iter() { if let Ok(base_abstracts) = base.get_attr("__abstractmethods__", vm) { @@ -220,7 +218,7 @@ mod _abc { ) -> PyResult { // Type check if !subclass.class().fast_issubclass(vm.ctx.types.type_type) { - return Err(vm.new_type_error("Can only register classes".to_owned())); + return Err(vm.new_type_error("Can only register classes")); } // Check if already a subclass @@ -230,7 +228,7 @@ mod _abc { // Check for cycles if cls.is_subclass(&subclass, vm)? { - return Err(vm.new_runtime_error("Refusing to create an inheritance cycle".to_owned())); + return Err(vm.new_runtime_error("Refusing to create an inheritance cycle")); } // Add to registry @@ -328,7 +326,7 @@ mod _abc { ) -> PyResult { // Type check if !subclass.class().fast_issubclass(vm.ctx.types.type_type) { - return Err(vm.new_type_error("issubclass() arg 1 must be a class".to_owned())); + return Err(vm.new_type_error("issubclass() arg 1 must be a class")); } let impl_data = get_impl(&cls, vm)?; @@ -373,11 +371,11 @@ mod _abc { let subclass_type: PyTypeRef = subclass .clone() .downcast() - .map_err(|_| vm.new_type_error("expected a type object".to_owned()))?; + .map_err(|_| vm.new_type_error("expected a type object"))?; let cls_type: PyTypeRef = cls .clone() .downcast() - .map_err(|_| vm.new_type_error("expected a type object".to_owned()))?; + .map_err(|_| vm.new_type_error("expected a type object"))?; if subclass_type.fast_issubclass(&cls_type) { add_to_weak_set(&impl_data.cache, &subclass, vm)?; return Ok(true); @@ -392,7 +390,7 @@ mod _abc { let subclasses: PyRef = vm .call_method(&cls, "__subclasses__", ())? 
.downcast() - .map_err(|_| vm.new_type_error("__subclasses__() must return a list".to_owned()))?; + .map_err(|_| vm.new_type_error("__subclasses__() must return a list"))?; for scls in subclasses.borrow_vec().iter() { if subclass.is_subclass(scls, vm)? { diff --git a/crates/vm/src/stdlib/ast.rs b/crates/vm/src/stdlib/_ast.rs similarity index 99% rename from crates/vm/src/stdlib/ast.rs rename to crates/vm/src/stdlib/_ast.rs index 1a4553c5d21..73819e257c1 100644 --- a/crates/vm/src/stdlib/ast.rs +++ b/crates/vm/src/stdlib/_ast.rs @@ -9,8 +9,8 @@ pub(crate) use python::_ast::module_def; mod pyast; use crate::builtins::{PyInt, PyStr}; -use crate::stdlib::ast::module::{Mod, ModFunctionType, ModInteractive}; -use crate::stdlib::ast::node::BoxedSlice; +use crate::stdlib::_ast::module::{Mod, ModFunctionType, ModInteractive}; +use crate::stdlib::_ast::node::BoxedSlice; use crate::{ AsObject, Context, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyRefExact, PyResult, TryFromObject, VirtualMachine, @@ -398,7 +398,7 @@ pub(crate) fn parse( #[cfg(feature = "parser")] pub(crate) fn wrap_interactive(vm: &VirtualMachine, module_obj: PyObjectRef) -> PyResult { if !module_obj.class().is(pyast::NodeModModule::static_type()) { - return Err(vm.new_type_error("expected Module node".to_owned())); + return Err(vm.new_type_error("expected Module node")); } let body = get_node_field(vm, &module_obj, "body", "Module")?; let node = NodeAst diff --git a/crates/vm/src/stdlib/ast/argument.rs b/crates/vm/src/stdlib/_ast/argument.rs similarity index 100% rename from crates/vm/src/stdlib/ast/argument.rs rename to crates/vm/src/stdlib/_ast/argument.rs diff --git a/crates/vm/src/stdlib/ast/basic.rs b/crates/vm/src/stdlib/_ast/basic.rs similarity index 100% rename from crates/vm/src/stdlib/ast/basic.rs rename to crates/vm/src/stdlib/_ast/basic.rs diff --git a/crates/vm/src/stdlib/ast/constant.rs b/crates/vm/src/stdlib/_ast/constant.rs similarity index 100% rename from 
crates/vm/src/stdlib/ast/constant.rs rename to crates/vm/src/stdlib/_ast/constant.rs diff --git a/crates/vm/src/stdlib/ast/elif_else_clause.rs b/crates/vm/src/stdlib/_ast/elif_else_clause.rs similarity index 100% rename from crates/vm/src/stdlib/ast/elif_else_clause.rs rename to crates/vm/src/stdlib/_ast/elif_else_clause.rs diff --git a/crates/vm/src/stdlib/ast/exception.rs b/crates/vm/src/stdlib/_ast/exception.rs similarity index 100% rename from crates/vm/src/stdlib/ast/exception.rs rename to crates/vm/src/stdlib/_ast/exception.rs diff --git a/crates/vm/src/stdlib/ast/expression.rs b/crates/vm/src/stdlib/_ast/expression.rs similarity index 99% rename from crates/vm/src/stdlib/ast/expression.rs rename to crates/vm/src/stdlib/_ast/expression.rs index 5e55b7b676b..cbc47dde9fb 100644 --- a/crates/vm/src/stdlib/ast/expression.rs +++ b/crates/vm/src/stdlib/_ast/expression.rs @@ -1,5 +1,5 @@ use super::*; -use crate::stdlib::ast::{ +use crate::stdlib::_ast::{ argument::{merge_function_call_arguments, split_function_call_arguments}, constant::Constant, string::JoinedStr, @@ -463,9 +463,7 @@ impl Node for ast::ExprDict { get_node_field(vm, &object, "values", "Dict")?, )?; if keys.len() != values.len() { - return Err(vm.new_value_error( - "Dict doesn't have the same number of keys as values".to_owned(), - )); + return Err(vm.new_value_error("Dict doesn't have the same number of keys as values")); } let items = keys .into_iter() diff --git a/crates/vm/src/stdlib/ast/module.rs b/crates/vm/src/stdlib/_ast/module.rs similarity index 99% rename from crates/vm/src/stdlib/ast/module.rs rename to crates/vm/src/stdlib/_ast/module.rs index cfedba606b0..b4c2468d33b 100644 --- a/crates/vm/src/stdlib/ast/module.rs +++ b/crates/vm/src/stdlib/_ast/module.rs @@ -1,5 +1,5 @@ use super::*; -use crate::stdlib::ast::type_ignore::TypeIgnore; +use crate::stdlib::_ast::type_ignore::TypeIgnore; use rustpython_compiler_core::SourceFile; /// Represents the different types of Python module 
structures. diff --git a/crates/vm/src/stdlib/ast/node.rs b/crates/vm/src/stdlib/_ast/node.rs similarity index 100% rename from crates/vm/src/stdlib/ast/node.rs rename to crates/vm/src/stdlib/_ast/node.rs diff --git a/crates/vm/src/stdlib/ast/operator.rs b/crates/vm/src/stdlib/_ast/operator.rs similarity index 100% rename from crates/vm/src/stdlib/ast/operator.rs rename to crates/vm/src/stdlib/_ast/operator.rs diff --git a/crates/vm/src/stdlib/ast/other.rs b/crates/vm/src/stdlib/_ast/other.rs similarity index 100% rename from crates/vm/src/stdlib/ast/other.rs rename to crates/vm/src/stdlib/_ast/other.rs diff --git a/crates/vm/src/stdlib/ast/parameter.rs b/crates/vm/src/stdlib/_ast/parameter.rs similarity index 100% rename from crates/vm/src/stdlib/ast/parameter.rs rename to crates/vm/src/stdlib/_ast/parameter.rs diff --git a/crates/vm/src/stdlib/ast/pattern.rs b/crates/vm/src/stdlib/_ast/pattern.rs similarity index 99% rename from crates/vm/src/stdlib/ast/pattern.rs rename to crates/vm/src/stdlib/_ast/pattern.rs index 3b665a95b55..621c849e812 100644 --- a/crates/vm/src/stdlib/ast/pattern.rs +++ b/crates/vm/src/stdlib/_ast/pattern.rs @@ -365,9 +365,7 @@ impl Node for ast::PatternMatchClass { get_node_field(vm, &object, "kwd_patterns", "MatchClass")?, )?; if kwd_attrs.0.len() != kwd_patterns.0.len() { - return Err(vm.new_value_error( - "MatchClass has mismatched kwd_attrs and kwd_patterns".to_owned(), - )); + return Err(vm.new_value_error("MatchClass has mismatched kwd_attrs and kwd_patterns")); } let (patterns, keywords) = merge_pattern_match_class(patterns, kwd_attrs, kwd_patterns); diff --git a/crates/vm/src/stdlib/ast/pyast.rs b/crates/vm/src/stdlib/_ast/pyast.rs similarity index 100% rename from crates/vm/src/stdlib/ast/pyast.rs rename to crates/vm/src/stdlib/_ast/pyast.rs diff --git a/crates/vm/src/stdlib/ast/python.rs b/crates/vm/src/stdlib/_ast/python.rs similarity index 99% rename from crates/vm/src/stdlib/ast/python.rs rename to 
crates/vm/src/stdlib/_ast/python.rs index c8883212a0c..5c97d759934 100644 --- a/crates/vm/src/stdlib/ast/python.rs +++ b/crates/vm/src/stdlib/_ast/python.rs @@ -12,7 +12,7 @@ pub(crate) mod _ast { class::{PyClassImpl, StaticType}, common::wtf8::Wtf8, function::{FuncArgs, KwArgs, PyMethodDef, PyMethodFlags}, - stdlib::ast::repr, + stdlib::_ast::repr, types::{Constructor, Initializer}, warn, }; @@ -118,7 +118,7 @@ pub(crate) mod _ast { pub(crate) fn ast_replace(zelf: PyObjectRef, args: FuncArgs, vm: &VirtualMachine) -> PyResult { if !args.args.is_empty() { - return Err(vm.new_type_error("__replace__() takes no positional arguments".to_owned())); + return Err(vm.new_type_error("__replace__() takes no positional arguments")); } let cls = zelf.class(); @@ -225,7 +225,7 @@ pub(crate) mod _ast { .map(|(key, value)| { let key = key .downcast::() - .map_err(|_| vm.new_type_error("keywords must be strings".to_owned()))?; + .map_err(|_| vm.new_type_error("keywords must be strings"))?; Ok((key.as_str().to_owned(), value)) }) .collect::>>()?; @@ -450,7 +450,7 @@ Support for arbitrary keyword arguments is deprecated and will be removed in Pyt let ast_type = module .get_attr("AST", vm)? 
.downcast::() - .map_err(|_| vm.new_type_error("AST is not a type".to_owned()))?; + .map_err(|_| vm.new_type_error("AST is not a type"))?; let ctx = &vm.ctx; let empty_tuple = ctx.empty_tuple.clone(); ast_type.set_str_attr("_fields", empty_tuple.clone(), ctx); diff --git a/crates/vm/src/stdlib/ast/repr.rs b/crates/vm/src/stdlib/_ast/repr.rs similarity index 99% rename from crates/vm/src/stdlib/ast/repr.rs rename to crates/vm/src/stdlib/_ast/repr.rs index 47fceb2386e..2897447fbec 100644 --- a/crates/vm/src/stdlib/ast/repr.rs +++ b/crates/vm/src/stdlib/_ast/repr.rs @@ -2,7 +2,7 @@ use crate::{ AsObject, PyObjectRef, PyResult, VirtualMachine, builtins::{PyList, PyTuple}, class::PyClassImpl, - stdlib::ast::NodeAst, + stdlib::_ast::NodeAst, }; use rustpython_common::wtf8::Wtf8Buf; diff --git a/crates/vm/src/stdlib/ast/statement.rs b/crates/vm/src/stdlib/_ast/statement.rs similarity index 99% rename from crates/vm/src/stdlib/ast/statement.rs rename to crates/vm/src/stdlib/_ast/statement.rs index 8b6ceb490a1..cda34da40a3 100644 --- a/crates/vm/src/stdlib/ast/statement.rs +++ b/crates/vm/src/stdlib/_ast/statement.rs @@ -1,5 +1,5 @@ use super::*; -use crate::stdlib::ast::argument::{merge_class_def_args, split_class_def_args}; +use crate::stdlib::_ast::argument::{merge_class_def_args, split_class_def_args}; use rustpython_compiler_core::SourceFile; // sum @@ -1112,11 +1112,10 @@ impl Node for ast::StmtImportFrom { let int: PyRef = obj.try_into_value(vm)?; let value: i64 = int.try_to_primitive(vm)?; if value < 0 { - return Err(vm.new_value_error("Negative ImportFrom level".to_owned())); + return Err(vm.new_value_error("Negative ImportFrom level")); } - u32::try_from(value).map_err(|_| { - vm.new_overflow_error("ImportFrom level out of range".to_owned()) - }) + u32::try_from(value) + .map_err(|_| vm.new_overflow_error("ImportFrom level out of range")) }) .transpose()? 
.unwrap_or(0), diff --git a/crates/vm/src/stdlib/ast/string.rs b/crates/vm/src/stdlib/_ast/string.rs similarity index 99% rename from crates/vm/src/stdlib/ast/string.rs rename to crates/vm/src/stdlib/_ast/string.rs index 4b6a6e8489f..24cae476694 100644 --- a/crates/vm/src/stdlib/ast/string.rs +++ b/crates/vm/src/stdlib/_ast/string.rs @@ -754,9 +754,7 @@ fn template_part_to_element( match part { TemplateStrPart::Constant(constant) => { let ConstantLiteral::Str { value, .. } = constant.value else { - return Err( - vm.new_type_error("TemplateStr constant values must be strings".to_owned()) - ); + return Err(vm.new_type_error("TemplateStr constant values must be strings")); }; Ok(ast::InterpolatedStringElement::Literal( ast::InterpolatedStringLiteralElement { diff --git a/crates/vm/src/stdlib/ast/type_ignore.rs b/crates/vm/src/stdlib/_ast/type_ignore.rs similarity index 100% rename from crates/vm/src/stdlib/ast/type_ignore.rs rename to crates/vm/src/stdlib/_ast/type_ignore.rs diff --git a/crates/vm/src/stdlib/ast/type_parameters.rs b/crates/vm/src/stdlib/_ast/type_parameters.rs similarity index 100% rename from crates/vm/src/stdlib/ast/type_parameters.rs rename to crates/vm/src/stdlib/_ast/type_parameters.rs diff --git a/crates/vm/src/stdlib/ast/validate.rs b/crates/vm/src/stdlib/_ast/validate.rs similarity index 96% rename from crates/vm/src/stdlib/ast/validate.rs rename to crates/vm/src/stdlib/_ast/validate.rs index ea5c2be840c..9957fe4ee39 100644 --- a/crates/vm/src/stdlib/ast/validate.rs +++ b/crates/vm/src/stdlib/_ast/validate.rs @@ -25,7 +25,7 @@ fn validate_name(vm: &VirtualMachine, name: &ast::name::Name) -> PyResult<()> { fn validate_comprehension(vm: &VirtualMachine, gens: &[ast::Comprehension]) -> PyResult<()> { if gens.is_empty() { - return Err(vm.new_value_error("comprehension with no generators".to_owned())); + return Err(vm.new_value_error("comprehension with no generators")); } for comp in gens { validate_expr(vm, &comp.target, 
ast::ExprContext::Store)?; @@ -133,31 +133,25 @@ fn validate_pattern_match_value(vm: &VirtualMachine, expr: &ast::Expr) -> PyResu ast::Expr::Attribute(_) => Ok(()), ast::Expr::UnaryOp(op) => match &*op.operand { ast::Expr::NumberLiteral(_) => Ok(()), - _ => Err(vm.new_value_error( - "patterns may only match literals and attribute lookups".to_owned(), - )), + _ => Err(vm.new_value_error("patterns may only match literals and attribute lookups")), }, ast::Expr::BinOp(bin) => match (&*bin.left, &*bin.right) { (ast::Expr::NumberLiteral(_), ast::Expr::NumberLiteral(_)) => Ok(()), - _ => Err(vm.new_value_error( - "patterns may only match literals and attribute lookups".to_owned(), - )), + _ => Err(vm.new_value_error("patterns may only match literals and attribute lookups")), }, ast::Expr::FString(_) | ast::Expr::TString(_) => Ok(()), ast::Expr::BooleanLiteral(_) | ast::Expr::NoneLiteral(_) | ast::Expr::EllipsisLiteral(_) => { - Err(vm.new_value_error("unexpected constant inside of a literal pattern".to_owned())) + Err(vm.new_value_error("unexpected constant inside of a literal pattern")) } - _ => Err( - vm.new_value_error("patterns may only match literals and attribute lookups".to_owned()) - ), + _ => Err(vm.new_value_error("patterns may only match literals and attribute lookups")), } } fn validate_capture(vm: &VirtualMachine, name: &ast::Identifier) -> PyResult<()> { if name.as_str() == "_" { - return Err(vm.new_value_error("can't capture name '_' in patterns".to_owned())); + return Err(vm.new_value_error("can't capture name '_' in patterns")); } validate_name(vm, name.id()) } @@ -172,7 +166,7 @@ fn validate_pattern(vm: &VirtualMachine, pattern: &ast::Pattern, star_ok: bool) ast::Pattern::MatchMapping(mapping) => { if mapping.keys.len() != mapping.patterns.len() { return Err(vm.new_value_error( - "MatchMapping doesn't have the same number of keys as patterns".to_owned(), + "MatchMapping doesn't have the same number of keys as patterns", )); } if let Some(rest) = 
&mapping.rest { @@ -197,8 +191,7 @@ fn validate_pattern(vm: &VirtualMachine, pattern: &ast::Pattern, star_ok: bool) } _ => { return Err(vm.new_value_error( - "MatchClass cls field can only contain Name or Attribute nodes." - .to_owned(), + "MatchClass cls field can only contain Name or Attribute nodes.", )); } } @@ -214,7 +207,7 @@ fn validate_pattern(vm: &VirtualMachine, pattern: &ast::Pattern, star_ok: bool) } ast::Pattern::MatchStar(star) => { if !star_ok { - return Err(vm.new_value_error("can't use MatchStar here".to_owned())); + return Err(vm.new_value_error("can't use MatchStar here")); } if let Some(name) = &star.name { validate_capture(vm, name)?; @@ -230,7 +223,7 @@ fn validate_pattern(vm: &VirtualMachine, pattern: &ast::Pattern, star_ok: bool) Some(pattern) => { if match_as.name.is_none() { return Err(vm.new_value_error( - "MatchAs must specify a target name if a pattern is given".to_owned(), + "MatchAs must specify a target name if a pattern is given", )); } validate_pattern(vm, pattern, false) @@ -239,7 +232,7 @@ fn validate_pattern(vm: &VirtualMachine, pattern: &ast::Pattern, star_ok: bool) } ast::Pattern::MatchOr(match_or) => { if match_or.patterns.len() < 2 { - return Err(vm.new_value_error("MatchOr requires at least 2 patterns".to_owned())); + return Err(vm.new_value_error("MatchOr requires at least 2 patterns")); } validate_patterns(vm, &match_or.patterns, false) } @@ -342,13 +335,13 @@ fn validate_expr(vm: &VirtualMachine, expr: &ast::Expr, ctx: ast::ExprContext) - match expr { ast::Expr::BoolOp(op) => { if op.values.len() < 2 { - return Err(vm.new_value_error("BoolOp with less than 2 values".to_owned())); + return Err(vm.new_value_error("BoolOp with less than 2 values")); } validate_exprs(vm, &op.values, ast::ExprContext::Load, false) } ast::Expr::Named(named) => { if !matches!(&*named.target, ast::Expr::Name(_)) { - return Err(vm.new_type_error("NamedExpr target must be a Name".to_owned())); + return Err(vm.new_type_error("NamedExpr target must 
be a Name")); } validate_expr(vm, &named.value, ast::ExprContext::Load) } @@ -409,11 +402,11 @@ fn validate_expr(vm: &VirtualMachine, expr: &ast::Expr, ctx: ast::ExprContext) - } ast::Expr::Compare(compare) => { if compare.comparators.is_empty() { - return Err(vm.new_value_error("Compare with no comparators".to_owned())); + return Err(vm.new_value_error("Compare with no comparators")); } if compare.comparators.len() != compare.ops.len() { return Err(vm.new_value_error( - "Compare has a different number of comparators and operands".to_owned(), + "Compare has a different number of comparators and operands", )); } validate_exprs(vm, &compare.comparators, ast::ExprContext::Load, false)?; @@ -519,7 +512,7 @@ fn validate_stmt(vm: &VirtualMachine, stmt: &ast::Stmt) -> PyResult<()> { } ast::Stmt::AnnAssign(assign) => { if assign.simple && !matches!(&*assign.target, ast::Expr::Name(_)) { - return Err(vm.new_type_error("AnnAssign with simple non-Name target".to_owned())); + return Err(vm.new_type_error("AnnAssign with simple non-Name target")); } validate_expr(vm, &assign.target, ast::ExprContext::Store)?; if let Some(value) = &assign.value { @@ -529,7 +522,7 @@ fn validate_stmt(vm: &VirtualMachine, stmt: &ast::Stmt) -> PyResult<()> { } ast::Stmt::TypeAlias(alias) => { if !matches!(&*alias.name, ast::Expr::Name(_)) { - return Err(vm.new_type_error("TypeAlias with non-Name name".to_owned())); + return Err(vm.new_type_error("TypeAlias with non-Name name")); } validate_expr(vm, &alias.name, ast::ExprContext::Store)?; validate_type_params(vm, &alias.type_params)?; @@ -592,7 +585,7 @@ fn validate_stmt(vm: &VirtualMachine, stmt: &ast::Stmt) -> PyResult<()> { validate_expr(vm, cause, ast::ExprContext::Load)?; } } else if raise.cause.is_some() { - return Err(vm.new_value_error("Raise with cause but no exception".to_owned())); + return Err(vm.new_value_error("Raise with cause but no exception")); } Ok(()) } diff --git a/crates/vm/src/stdlib/codecs.rs b/crates/vm/src/stdlib/_codecs.rs 
similarity index 99% rename from crates/vm/src/stdlib/codecs.rs rename to crates/vm/src/stdlib/_codecs.rs index 6c37ee4c9f9..39ebb3599bd 100644 --- a/crates/vm/src/stdlib/codecs.rs +++ b/crates/vm/src/stdlib/_codecs.rs @@ -96,7 +96,7 @@ mod _codecs { vm: &VirtualMachine, ) -> PyResult<()> { if !handler.is_callable() { - return Err(vm.new_type_error("handler must be callable".to_owned())); + return Err(vm.new_type_error("handler must be callable")); } vm.state .codec_registry @@ -398,7 +398,7 @@ mod _codecs_windows { None => { // String contains surrogates - not encodable with mbcs return Err(vm.new_unicode_encode_error( - "'mbcs' codec can't encode character: surrogates not allowed".to_string(), + "'mbcs' codec can't encode character: surrogates not allowed", )); } }; @@ -584,7 +584,7 @@ mod _codecs_windows { None => { // String contains surrogates - not encodable with oem return Err(vm.new_unicode_encode_error( - "'oem' codec can't encode character: surrogates not allowed".to_string(), + "'oem' codec can't encode character: surrogates not allowed", )); } }; @@ -1052,7 +1052,7 @@ mod _codecs_windows { use crate::common::windows::ToWideString; if args.code_page < 0 { - return Err(vm.new_value_error("invalid code page number".to_owned())); + return Err(vm.new_value_error("invalid code page number")); } let errors = args.errors.as_ref().map(|s| s.as_str()).unwrap_or("strict"); let code_page = args.code_page as u32; @@ -1365,7 +1365,7 @@ mod _codecs_windows { use crate::common::wtf8::Wtf8Buf; if args.code_page < 0 { - return Err(vm.new_value_error("invalid code page number".to_owned())); + return Err(vm.new_value_error("invalid code page number")); } let errors = args.errors.as_ref().map(|s| s.as_str()).unwrap_or("strict"); let code_page = args.code_page as u32; diff --git a/crates/vm/src/stdlib/collections.rs b/crates/vm/src/stdlib/_collections.rs similarity index 99% rename from crates/vm/src/stdlib/collections.rs rename to crates/vm/src/stdlib/_collections.rs index 
80f80e2d28f..2807e171777 100644 --- a/crates/vm/src/stdlib/collections.rs +++ b/crates/vm/src/stdlib/_collections.rs @@ -56,7 +56,7 @@ mod _collections { } #[pyclass( - flags(BASETYPE), + flags(BASETYPE, HAS_WEAKREF), with( Constructor, Initializer, diff --git a/crates/vm/src/stdlib/ctypes.rs b/crates/vm/src/stdlib/_ctypes.rs similarity index 99% rename from crates/vm/src/stdlib/ctypes.rs rename to crates/vm/src/stdlib/_ctypes.rs index 8ab52f3bcef..2534f6128e8 100644 --- a/crates/vm/src/stdlib/ctypes.rs +++ b/crates/vm/src/stdlib/_ctypes.rs @@ -930,7 +930,7 @@ pub(crate) mod _ctypes { let buffer = cdata.buffer.read(); if matches!(&*buffer, Cow::Borrowed(_)) { return Err(vm.new_value_error( - "Memory cannot be resized because this object doesn't own it".to_owned(), + "Memory cannot be resized because this object doesn't own it", )); } } diff --git a/crates/vm/src/stdlib/ctypes/array.rs b/crates/vm/src/stdlib/_ctypes/array.rs similarity index 99% rename from crates/vm/src/stdlib/ctypes/array.rs rename to crates/vm/src/stdlib/_ctypes/array.rs index 0672d0cbe80..568e2a4a0a9 100644 --- a/crates/vm/src/stdlib/ctypes/array.rs +++ b/crates/vm/src/stdlib/_ctypes/array.rs @@ -798,9 +798,10 @@ impl PyCArray { } else if let Ok(int_val) = value.try_index(vm) { (int_val.as_bigint().to_usize().unwrap_or(0), None) } else { - return Err(vm.new_type_error( - "bytes or integer address expected instead of {}".to_owned(), - )); + return Err(vm.new_type_error(format!( + "bytes or integer address expected instead of {} instance", + value.class().name() + ))); }; if offset + element_size <= buffer.len() { buffer[offset..offset + element_size].copy_from_slice(&ptr_val.to_ne_bytes()); diff --git a/crates/vm/src/stdlib/ctypes/base.rs b/crates/vm/src/stdlib/_ctypes/base.rs similarity index 99% rename from crates/vm/src/stdlib/ctypes/base.rs rename to crates/vm/src/stdlib/_ctypes/base.rs index 90137f2549d..0bfbe57bb04 100644 --- a/crates/vm/src/stdlib/ctypes/base.rs +++ 
b/crates/vm/src/stdlib/_ctypes/base.rs @@ -1433,8 +1433,7 @@ impl Constructor for PyCField { if !internal_use { return Err(vm.new_type_error( - "CField is not intended to be used directly; use it via Structure or Union fields" - .to_string(), + "CField is not intended to be used directly; use it via Structure or Union fields", )); } @@ -1493,11 +1492,11 @@ impl Constructor for PyCField { if let Some(bs) = bit_size_val { if bs < 0 { - return Err(vm.new_value_error("number of bits invalid for bit field".to_string())); + return Err(vm.new_value_error("number of bits invalid for bit field")); } let bo = bit_offset_val.unwrap_or(0); if bo < 0 { - return Err(vm.new_value_error("bit_offset must be >= 0".to_string())); + return Err(vm.new_value_error("bit_offset must be >= 0")); } let type_bits = byte_size * 8; if bo + bs > type_bits { diff --git a/crates/vm/src/stdlib/ctypes/function.rs b/crates/vm/src/stdlib/_ctypes/function.rs similarity index 99% rename from crates/vm/src/stdlib/ctypes/function.rs rename to crates/vm/src/stdlib/_ctypes/function.rs index bf6dcfad53d..e28fd91abbc 100644 --- a/crates/vm/src/stdlib/ctypes/function.rs +++ b/crates/vm/src/stdlib/_ctypes/function.rs @@ -232,7 +232,7 @@ fn convert_to_pointer(value: &PyObject, vm: &VirtualMachine) -> PyResult { @@ -559,8 +557,7 @@ impl Initializer for PyCSimpleType { // Validate _type_ is a single character if type_str.len() != 1 { return Err(vm.new_value_error( - "class must define a '_type_' attribute which must be a string of length 1" - .to_owned(), + "class must define a '_type_' attribute which must be a string of length 1", )); } diff --git a/crates/vm/src/stdlib/ctypes/structure.rs b/crates/vm/src/stdlib/_ctypes/structure.rs similarity index 98% rename from crates/vm/src/stdlib/ctypes/structure.rs rename to crates/vm/src/stdlib/_ctypes/structure.rs index 69d267f287e..c9ab9205601 100644 --- a/crates/vm/src/stdlib/ctypes/structure.rs +++ b/crates/vm/src/stdlib/_ctypes/structure.rs @@ -3,7 +3,7 @@ use 
crate::builtins::{PyList, PyStr, PyTuple, PyType, PyTypeRef, PyUtf8Str}; use crate::convert::ToPyObject; use crate::function::{FuncArgs, OptionalArg, PySetterValue}; use crate::protocol::{BufferDescriptor, PyBuffer, PyNumberMethods}; -use crate::stdlib::warnings; +use crate::stdlib::_warnings; use crate::types::{AsBuffer, AsNumber, Constructor, Initializer, SetAttr}; use crate::{AsObject, Py, PyObjectRef, PyPayload, PyResult, VirtualMachine}; use alloc::borrow::Cow; @@ -259,7 +259,7 @@ impl PyCStructType { cls.name(), base_type_name, ); - warnings::warn(vm.ctx.exceptions.deprecation_warning, msg, 1, vm)?; + _warnings::warn(vm.ctx.exceptions.deprecation_warning, msg, 1, vm)?; } } @@ -310,9 +310,9 @@ impl PyCStructType { .ok_or_else(|| vm.new_type_error("_fields_ must contain tuples"))?; if field_tuple.len() < 2 { - return Err(vm.new_type_error( - "_fields_ tuple must have at least 2 elements (name, type)".to_string(), - )); + return Err( + vm.new_type_error("_fields_ tuple must have at least 2 elements (name, type)") + ); } let name = field_tuple @@ -428,9 +428,7 @@ impl PyCStructType { .try_int(vm)? 
.as_bigint() .to_u16() - .ok_or_else(|| { - vm.new_value_error("number of bits invalid for bit field".to_string()) - })?; + .ok_or_else(|| vm.new_value_error("number of bits invalid for bit field"))?; has_bitfield = true; let type_bits = (size * 8) as u16; @@ -534,9 +532,7 @@ impl PyCStructType { if let Some(stg_info) = cls.get_type_data::() && stg_info.is_final() { - return Err( - vm.new_attribute_error("Structure or union cannot contain itself".to_string()) - ); + return Err(vm.new_attribute_error("Structure or union cannot contain itself")); } // Store StgInfo with aligned size and total alignment diff --git a/crates/vm/src/stdlib/ctypes/union.rs b/crates/vm/src/stdlib/_ctypes/union.rs similarity index 97% rename from crates/vm/src/stdlib/ctypes/union.rs rename to crates/vm/src/stdlib/_ctypes/union.rs index 7526fa92eff..c1882141e98 100644 --- a/crates/vm/src/stdlib/ctypes/union.rs +++ b/crates/vm/src/stdlib/_ctypes/union.rs @@ -4,7 +4,7 @@ use crate::builtins::{PyList, PyStr, PyTuple, PyType, PyTypeRef, PyUtf8Str}; use crate::convert::ToPyObject; use crate::function::{ArgBytesLike, FuncArgs, OptionalArg, PySetterValue}; use crate::protocol::{BufferDescriptor, PyBuffer}; -use crate::stdlib::warnings; +use crate::stdlib::_warnings; use crate::types::{AsBuffer, Constructor, Initializer, SetAttr}; use crate::{AsObject, Py, PyObjectRef, PyPayload, PyResult, VirtualMachine}; use alloc::borrow::Cow; @@ -177,7 +177,7 @@ impl PyCUnionType { Python 3.19.", cls.name(), ); - warnings::warn(vm.ctx.exceptions.deprecation_warning, msg, 1, vm)?; + _warnings::warn(vm.ctx.exceptions.deprecation_warning, msg, 1, vm)?; } } @@ -209,9 +209,9 @@ impl PyCUnionType { .ok_or_else(|| vm.new_type_error("_fields_ must contain tuples"))?; if field_tuple.len() < 2 { - return Err(vm.new_type_error( - "_fields_ tuple must have at least 2 elements (name, type)".to_string(), - )); + return Err( + vm.new_type_error("_fields_ tuple must have at least 2 elements (name, type)") + ); } let name = 
field_tuple @@ -286,9 +286,7 @@ impl PyCUnionType { .try_int(vm)? .as_bigint() .to_u16() - .ok_or_else(|| { - vm.new_value_error("number of bits invalid for bit field".to_string()) - })?; + .ok_or_else(|| vm.new_value_error("number of bits invalid for bit field"))?; has_bitfield = true; // Union fields all start at offset 0, so bit_offset = 0 @@ -329,9 +327,7 @@ impl PyCUnionType { if let Some(stg_info) = cls.get_type_data::() && stg_info.is_final() { - return Err( - vm.new_attribute_error("Structure or union cannot contain itself".to_string()) - ); + return Err(vm.new_attribute_error("Structure or union cannot contain itself")); } // Store StgInfo with aligned size diff --git a/crates/vm/src/stdlib/functools.rs b/crates/vm/src/stdlib/_functools.rs similarity index 99% rename from crates/vm/src/stdlib/functools.rs rename to crates/vm/src/stdlib/_functools.rs index 2c3f70ab52a..494f0e7fd83 100644 --- a/crates/vm/src/stdlib/functools.rs +++ b/crates/vm/src/stdlib/_functools.rs @@ -79,7 +79,7 @@ mod _functools { fn slot_new(cls: PyTypeRef, args: FuncArgs, vm: &VirtualMachine) -> PyResult { if !args.args.is_empty() || !args.kwargs.is_empty() { - return Err(vm.new_type_error("_PlaceholderType takes no arguments".to_owned())); + return Err(vm.new_type_error("_PlaceholderType takes no arguments")); } // Return the singleton stored on the type class if let Some(instance) = cls.get_attr(vm.ctx.intern_str("_instance")) { @@ -104,7 +104,7 @@ mod _functools { #[pymethod] fn __init_subclass__(_cls: PyTypeRef, vm: &VirtualMachine) -> PyResult<()> { - Err(vm.new_type_error("cannot subclass '_PlaceholderType'".to_owned())) + Err(vm.new_type_error("cannot subclass '_PlaceholderType'")) } } @@ -140,7 +140,7 @@ mod _functools { #[pyclass( with(Constructor, Callable, GetDescriptor, Representable), - flags(BASETYPE, HAS_DICT) + flags(BASETYPE, HAS_DICT, HAS_WEAKREF) )] impl PyPartial { #[pygetset] @@ -245,7 +245,7 @@ mod _functools { // Validate no trailing placeholders let args_slice 
= args_tuple.as_slice(); if !args_slice.is_empty() && is_placeholder(args_slice.last().unwrap()) { - return Err(vm.new_type_error("trailing Placeholders are not allowed".to_owned())); + return Err(vm.new_type_error("trailing Placeholders are not allowed")); } let phcount = count_placeholders(args_slice); @@ -354,7 +354,7 @@ mod _functools { // Trailing placeholders are not allowed if !final_args.is_empty() && is_placeholder(final_args.last().unwrap()) { - return Err(vm.new_type_error("trailing Placeholders are not allowed".to_owned())); + return Err(vm.new_type_error("trailing Placeholders are not allowed")); } let phcount = count_placeholders(&final_args); diff --git a/crates/vm/src/stdlib/imp.rs b/crates/vm/src/stdlib/_imp.rs similarity index 88% rename from crates/vm/src/stdlib/imp.rs rename to crates/vm/src/stdlib/_imp.rs index 087556c8cf2..c0acb304a64 100644 --- a/crates/vm/src/stdlib/imp.rs +++ b/crates/vm/src/stdlib/_imp.rs @@ -10,13 +10,13 @@ pub use crate::vm::resolve_frozen_alias; #[cfg(feature = "threading")] #[pymodule(sub)] mod lock { - use crate::{PyResult, VirtualMachine, stdlib::thread::RawRMutex}; + use crate::{PyResult, VirtualMachine, stdlib::_thread::RawRMutex}; static IMP_LOCK: RawRMutex = RawRMutex::INIT; #[pyfunction] fn acquire_lock(_vm: &VirtualMachine) { - IMP_LOCK.lock() + acquire_lock_for_fork() } #[pyfunction] @@ -34,6 +34,16 @@ mod lock { IMP_LOCK.is_locked() } + pub(super) fn acquire_lock_for_fork() { + IMP_LOCK.lock(); + } + + pub(super) fn release_lock_after_fork_parent() { + if IMP_LOCK.is_locked() && IMP_LOCK.is_owned_by_current_thread() { + unsafe { IMP_LOCK.unlock() }; + } + } + /// Reset import lock after fork() — only if held by a dead thread. /// /// `IMP_LOCK` is a reentrant mutex. If the *current* (surviving) thread @@ -47,22 +57,44 @@ mod lock { pub(crate) unsafe fn reinit_after_fork() { if IMP_LOCK.is_locked() && !IMP_LOCK.is_owned_by_current_thread() { // Held by a dead thread — reset to unlocked. 
- // Same pattern as RLock::_at_fork_reinit in thread.rs. - unsafe { - let old: &crossbeam_utils::atomic::AtomicCell = - core::mem::transmute(&IMP_LOCK); - old.swap(RawRMutex::INIT); - } + unsafe { rustpython_common::lock::zero_reinit_after_fork(&IMP_LOCK) }; + } + } + + /// Match CPython's `_PyImport_ReInitLock()` + `_PyImport_ReleaseLock()` + /// behavior in the post-fork child: + /// 1) if ownership metadata is stale (dead owner / changed tid), reset; + /// 2) if current thread owns the lock, release it. + #[cfg(unix)] + pub(super) unsafe fn after_fork_child_reinit_and_release() { + unsafe { reinit_after_fork() }; + if IMP_LOCK.is_locked() && IMP_LOCK.is_owned_by_current_thread() { + unsafe { IMP_LOCK.unlock() }; } } } /// Re-export for fork safety code in posix.rs +#[cfg(feature = "threading")] +pub(crate) fn acquire_imp_lock_for_fork() { + lock::acquire_lock_for_fork(); +} + +#[cfg(feature = "threading")] +pub(crate) fn release_imp_lock_after_fork_parent() { + lock::release_lock_after_fork_parent(); +} + #[cfg(all(unix, feature = "threading"))] pub(crate) unsafe fn reinit_imp_lock_after_fork() { unsafe { lock::reinit_after_fork() } } +#[cfg(all(unix, feature = "threading"))] +pub(crate) unsafe fn after_fork_child_imp_lock_release() { + unsafe { lock::after_fork_child_reinit_and_release() } +} + #[cfg(not(feature = "threading"))] #[pymodule(sub)] mod lock { diff --git a/crates/vm/src/stdlib/io.rs b/crates/vm/src/stdlib/_io.rs similarity index 99% rename from crates/vm/src/stdlib/io.rs rename to crates/vm/src/stdlib/_io.rs index 945042bc9e4..c238dda3725 100644 --- a/crates/vm/src/stdlib/io.rs +++ b/crates/vm/src/stdlib/_io.rs @@ -413,7 +413,10 @@ mod _io { #[derive(Debug, Default, PyPayload)] pub struct _IOBase; - #[pyclass(with(IterNext, Iterable, Destructor), flags(BASETYPE, HAS_DICT))] + #[pyclass( + with(IterNext, Iterable, Destructor), + flags(BASETYPE, HAS_DICT, HAS_WEAKREF) + )] impl _IOBase { #[pymethod] fn seek( @@ -634,7 +637,7 @@ mod _io { 
#[repr(transparent)] pub(super) struct _RawIOBase(_IOBase); - #[pyclass(flags(BASETYPE, HAS_DICT))] + #[pyclass(flags(BASETYPE, HAS_DICT, HAS_WEAKREF))] impl _RawIOBase { #[pymethod] fn read(instance: PyObjectRef, size: OptionalSize, vm: &VirtualMachine) -> PyResult { @@ -720,7 +723,7 @@ mod _io { #[repr(transparent)] struct _BufferedIOBase(_IOBase); - #[pyclass(flags(BASETYPE))] + #[pyclass(flags(BASETYPE, HAS_WEAKREF))] impl _BufferedIOBase { #[pymethod] fn read(zelf: PyObjectRef, _size: OptionalArg, vm: &VirtualMachine) -> PyResult { @@ -785,7 +788,7 @@ mod _io { #[repr(transparent)] struct _TextIOBase(_IOBase); - #[pyclass(flags(BASETYPE))] + #[pyclass(flags(BASETYPE, HAS_WEAKREF))] impl _TextIOBase { #[pygetset] fn encoding(_zelf: PyObjectRef, vm: &VirtualMachine) -> PyObjectRef { @@ -1577,7 +1580,7 @@ mod _io { fn lock(&self, vm: &VirtualMachine) -> PyResult> { self.data() - .lock() + .lock_wrapped(|do_lock| vm.allow_threads(do_lock)) .ok_or_else(|| vm.new_runtime_error("reentrant call inside buffered io")) } @@ -1981,7 +1984,7 @@ mod _io { #[pyclass( with(Constructor, BufferedMixin, BufferedReadable, Destructor), - flags(BASETYPE, HAS_DICT) + flags(BASETYPE, HAS_DICT, HAS_WEAKREF) )] impl BufferedReader {} @@ -2029,7 +2032,7 @@ mod _io { // Yield to other threads std::thread::yield_now(); } - return Err(vm.new_value_error("write to closed file".to_owned())); + return Err(vm.new_value_error("write to closed file")); } let mut data = self.writer().lock(vm)?; let raw = data.check_init(vm)?; @@ -2085,7 +2088,7 @@ mod _io { #[pyclass( with(Constructor, BufferedMixin, BufferedWritable, Destructor), - flags(BASETYPE, HAS_DICT) + flags(BASETYPE, HAS_DICT, HAS_WEAKREF) )] impl BufferedWriter {} @@ -2159,7 +2162,7 @@ mod _io { BufferedWritable, Destructor ), - flags(BASETYPE, HAS_DICT) + flags(BASETYPE, HAS_DICT, HAS_WEAKREF) )] impl BufferedRandom {} @@ -2229,7 +2232,7 @@ mod _io { BufferedWritable, Destructor ), - flags(BASETYPE, HAS_DICT) + flags(BASETYPE, 
HAS_DICT, HAS_WEAKREF) )] impl BufferedRWPair { #[pymethod] @@ -2809,7 +2812,7 @@ mod _io { vm: &VirtualMachine, ) -> PyResult>> { self.data - .lock() + .lock_wrapped(|do_lock| vm.allow_threads(do_lock)) .ok_or_else(|| vm.new_runtime_error("reentrant call inside textio")) } @@ -3015,7 +3018,7 @@ mod _io { IterNext, Representable ), - flags(BASETYPE) + flags(BASETYPE, HAS_WEAKREF) )] impl TextIOWrapper { #[pymethod] @@ -4016,7 +4019,7 @@ mod _io { return Ok(vm.ctx.new_str(Wtf8Buf::from(format!("<{type_name}>")))); }; let Some(data) = data.as_ref() else { - return Err(vm.new_value_error("I/O operation on uninitialized object".to_owned())); + return Err(vm.new_value_error("I/O operation on uninitialized object")); }; let mut result = Wtf8Buf::from(format!("<{type_name}")); @@ -4155,7 +4158,7 @@ mod _io { vm: &VirtualMachine, ) -> PyResult>> { self.data - .lock() + .lock_wrapped(|do_lock| vm.allow_threads(do_lock)) .ok_or_else(|| vm.new_runtime_error("reentrant call inside nldecoder")) } @@ -4376,7 +4379,7 @@ mod _io { } } - #[pyclass(flags(BASETYPE, HAS_DICT), with(Constructor, Initializer))] + #[pyclass(flags(BASETYPE, HAS_DICT, HAS_WEAKREF), with(Constructor, Initializer))] impl StringIO { #[pymethod] const fn readable(&self) -> bool { @@ -4569,9 +4572,9 @@ mod _io { fn init(zelf: PyRef, args: Self::Args, vm: &VirtualMachine) -> PyResult<()> { if zelf.exports.load() > 0 { - return Err(vm.new_buffer_error( - "Existing exports of data: object cannot be re-sized".to_owned(), - )); + return Err( + vm.new_buffer_error("Existing exports of data: object cannot be re-sized") + ); } let raw_bytes = args @@ -4593,7 +4596,10 @@ mod _io { } } - #[pyclass(flags(BASETYPE, HAS_DICT), with(PyRef, Constructor, Initializer))] + #[pyclass( + flags(BASETYPE, HAS_DICT, HAS_WEAKREF), + with(PyRef, Constructor, Initializer) + )] impl BytesIO { #[pymethod] const fn readable(&self) -> bool { @@ -4706,9 +4712,9 @@ mod _io { #[pymethod] fn close(&self, vm: &VirtualMachine) -> PyResult<()> { 
if self.exports.load() > 0 { - return Err(vm.new_buffer_error( - "Existing exports of data: object cannot be closed".to_owned(), - )); + return Err( + vm.new_buffer_error("Existing exports of data: object cannot be closed") + ); } self.closed.store(true); Ok(()) @@ -4788,7 +4794,7 @@ mod _io { #[pymethod] fn getbuffer(self, vm: &VirtualMachine) -> PyResult { if self.closed.load() { - return Err(vm.new_value_error("I/O operation on closed file.".to_owned())); + return Err(vm.new_value_error("I/O operation on closed file.")); } let len = self.buffer.read().cursor.get_ref().len(); let buffer = PyBuffer::new( @@ -5009,13 +5015,13 @@ mod _io { if let Some(tio) = obj.downcast_ref::() { unsafe { reinit_thread_mutex_after_fork(&tio.data) }; - if let Some(guard) = tio.data.lock() { - if let Some(ref data) = *guard { - if let Some(ref decoder) = data.decoder { - reinit_io_locks(decoder); - } - reinit_io_locks(&data.buffer); + if let Some(guard) = tio.data.lock() + && let Some(ref data) = *guard + { + if let Some(ref decoder) = data.decoder { + reinit_io_locks(decoder); } + reinit_io_locks(&data.buffer); } return; } @@ -5038,7 +5044,6 @@ mod _io { if let Some(brw) = obj.downcast_ref::() { unsafe { reinit_thread_mutex_after_fork(&brw.read.data) }; unsafe { reinit_thread_mutex_after_fork(&brw.write.data) }; - return; } } @@ -5117,7 +5122,7 @@ mod _io { // Warn if line buffering is requested in binary mode if opts.buffering == 1 && matches!(mode.encode, EncodeMode::Bytes) { - crate::stdlib::warnings::warn( + crate::stdlib::_warnings::warn( vm.ctx.exceptions.runtime_warning, "line buffering (buffering=1) isn't supported in binary mode, the default buffer size will be used".to_owned(), 1, @@ -5243,7 +5248,7 @@ mod _io { } } let stacklevel = usize::try_from(stacklevel).unwrap_or(0); - crate::stdlib::warnings::warn( + crate::stdlib::_warnings::warn( vm.ctx.exceptions.encoding_warning, "'encoding' argument not specified".to_owned(), stacklevel, @@ -5331,7 +5336,7 @@ mod fileio { 
types::{Constructor, DefaultConstructor, Destructor, Initializer, Representable}, }; use crossbeam_utils::atomic::AtomicCell; - use std::io::{Read, Write}; + use std::io::Read; bitflags::bitflags! { #[derive(Copy, Clone, Debug, PartialEq)] @@ -5483,7 +5488,7 @@ mod fileio { let name = args.name; // Check if bool is used as file descriptor if name.class().is(vm.ctx.types.bool_type) { - crate::stdlib::warnings::warn( + crate::stdlib::_warnings::warn( vm.ctx.exceptions.runtime_warning, "bool is used as a file descriptor".to_owned(), 1, @@ -5634,7 +5639,7 @@ mod fileio { #[pyclass( with(Constructor, Initializer, Representable, Destructor), - flags(BASETYPE, HAS_DICT) + flags(BASETYPE, HAS_DICT, HAS_WEAKREF) )] impl FileIO { fn io_error( @@ -5735,12 +5740,12 @@ mod fileio { "File or stream is not readable".to_owned(), )); } - let mut handle = zelf.get_fd(vm)?; + let handle = zelf.get_fd(vm)?; let bytes = if let Some(read_byte) = read_byte.to_usize() { let mut bytes = vec![0; read_byte]; // Loop on EINTR (PEP 475) let n = loop { - match handle.read(&mut bytes) { + match vm.allow_threads(|| crt_fd::read(handle, &mut bytes)) { Ok(n) => break n, Err(e) if e.raw_os_error() == Some(libc::EINTR) => { vm.check_signals()?; @@ -5759,7 +5764,10 @@ mod fileio { let mut bytes = vec![]; // Loop on EINTR (PEP 475) loop { - match handle.read_to_end(&mut bytes) { + match vm.allow_threads(|| { + let mut h = handle; + h.read_to_end(&mut bytes) + }) { Ok(_) => break, Err(e) if e.raw_os_error() == Some(libc::EINTR) => { vm.check_signals()?; @@ -5797,10 +5805,9 @@ mod fileio { let handle = zelf.get_fd(vm)?; let mut buf = obj.borrow_buf_mut(); - let mut f = handle.take(buf.len() as _); // Loop on EINTR (PEP 475) let ret = loop { - match f.read(&mut buf) { + match vm.allow_threads(|| crt_fd::read(handle, &mut buf)) { Ok(n) => break n, Err(e) if e.raw_os_error() == Some(libc::EINTR) => { vm.check_signals()?; @@ -5830,11 +5837,11 @@ mod fileio { )); } - let mut handle = zelf.get_fd(vm)?; + let 
handle = zelf.get_fd(vm)?; // Loop on EINTR (PEP 475) let len = loop { - match obj.with_ref(|b| handle.write(b)) { + match obj.with_ref(|b| vm.allow_threads(|| crt_fd::write(handle, b))) { Ok(n) => break n, Err(e) if e.raw_os_error() == Some(libc::EINTR) => { vm.check_signals()?; @@ -5951,7 +5958,7 @@ mod fileio { .repr(vm) .map(|s| s.as_wtf8().to_owned()) .unwrap_or_else(|_| Wtf8Buf::from("")); - if let Err(e) = crate::stdlib::warnings::warn( + if let Err(e) = crate::stdlib::_warnings::warn( vm.ctx.exceptions.resource_warning, format!("unclosed file {repr}"), 1, @@ -6190,7 +6197,7 @@ mod winconsoleio { // Warn if bool is used as file descriptor if nameobj.class().is(vm.ctx.types.bool_type) { - crate::stdlib::warnings::warn( + crate::stdlib::_warnings::warn( vm.ctx.exceptions.runtime_warning, "bool is used as a file descriptor".to_owned(), 1, @@ -6390,7 +6397,7 @@ mod winconsoleio { #[pyclass( with(Constructor, Initializer, Representable, Destructor), - flags(BASETYPE, HAS_DICT) + flags(BASETYPE, HAS_DICT, HAS_WEAKREF) )] impl WindowsConsoleIO { #[allow(dead_code)] @@ -6504,7 +6511,7 @@ mod winconsoleio { .repr(vm) .map(|s| s.as_wtf8().to_owned()) .unwrap_or_else(|_| Wtf8Buf::from("")); - if let Err(e) = crate::stdlib::warnings::warn( + if let Err(e) = crate::stdlib::_warnings::warn( vm.ctx.exceptions.resource_warning, format!("unclosed file {repr}"), 1, @@ -6970,7 +6977,7 @@ mod winconsoleio { #[pymethod(name = "__reduce__")] fn reduce(_zelf: &Py, vm: &VirtualMachine) -> PyResult { - Err(vm.new_type_error("cannot pickle '_WindowsConsoleIO' instances".to_owned())) + Err(vm.new_type_error("cannot pickle '_WindowsConsoleIO' instances")) } } diff --git a/crates/vm/src/stdlib/operator.rs b/crates/vm/src/stdlib/_operator.rs similarity index 100% rename from crates/vm/src/stdlib/operator.rs rename to crates/vm/src/stdlib/_operator.rs diff --git a/crates/vm/src/stdlib/signal.rs b/crates/vm/src/stdlib/_signal.rs similarity index 99% rename from 
crates/vm/src/stdlib/signal.rs rename to crates/vm/src/stdlib/_signal.rs index e6ad7b53348..a69d766ce51 100644 --- a/crates/vm/src/stdlib/signal.rs +++ b/crates/vm/src/stdlib/_signal.rs @@ -401,8 +401,7 @@ pub(crate) mod _signal { } // Validate that fd is a valid file descriptor using fstat // First check if SOCKET can be safely cast to i32 (file descriptor) - let fd_i32 = - i32::try_from(fd).map_err(|_| vm.new_value_error("invalid fd".to_owned()))?; + let fd_i32 = i32::try_from(fd).map_err(|_| vm.new_value_error("invalid fd"))?; // Verify the fd is valid by trying to fstat it let borrowed_fd = unsafe { crate::common::crt_fd::Borrowed::try_borrow_raw(fd_i32) } @@ -458,7 +457,7 @@ pub(crate) mod _signal { if let OptionalArg::Present(obj) = siginfo && !vm.is_none(&obj) { - return Err(vm.new_type_error("siginfo must be None".to_owned())); + return Err(vm.new_type_error("siginfo must be None")); } let flags = flags.unwrap_or(0); diff --git a/crates/vm/src/stdlib/sre.rs b/crates/vm/src/stdlib/_sre.rs similarity index 99% rename from crates/vm/src/stdlib/sre.rs rename to crates/vm/src/stdlib/_sre.rs index 2c18bab4ba1..ba7044fb5a9 100644 --- a/crates/vm/src/stdlib/sre.rs +++ b/crates/vm/src/stdlib/_sre.rs @@ -212,7 +212,7 @@ mod _sre { }; } - #[pyclass(with(Hashable, Comparable, Representable))] + #[pyclass(with(Hashable, Comparable, Representable), flags(HAS_WEAKREF))] impl Pattern { fn with_str(string: &PyObject, vm: &VirtualMachine, f: F) -> PyResult where diff --git a/crates/vm/src/stdlib/stat.rs b/crates/vm/src/stdlib/_stat.rs similarity index 100% rename from crates/vm/src/stdlib/stat.rs rename to crates/vm/src/stdlib/_stat.rs diff --git a/crates/vm/src/stdlib/string.rs b/crates/vm/src/stdlib/_string.rs similarity index 100% rename from crates/vm/src/stdlib/string.rs rename to crates/vm/src/stdlib/_string.rs diff --git a/crates/vm/src/stdlib/symtable.rs b/crates/vm/src/stdlib/_symtable.rs similarity index 100% rename from crates/vm/src/stdlib/symtable.rs rename to 
crates/vm/src/stdlib/_symtable.rs diff --git a/crates/vm/src/stdlib/sysconfig.rs b/crates/vm/src/stdlib/_sysconfig.rs similarity index 100% rename from crates/vm/src/stdlib/sysconfig.rs rename to crates/vm/src/stdlib/_sysconfig.rs diff --git a/crates/vm/src/stdlib/sysconfigdata.rs b/crates/vm/src/stdlib/_sysconfigdata.rs similarity index 100% rename from crates/vm/src/stdlib/sysconfigdata.rs rename to crates/vm/src/stdlib/_sysconfigdata.rs diff --git a/crates/vm/src/stdlib/thread.rs b/crates/vm/src/stdlib/_thread.rs similarity index 62% rename from crates/vm/src/stdlib/thread.rs rename to crates/vm/src/stdlib/_thread.rs index 45be328dc1e..765f2537440 100644 --- a/crates/vm/src/stdlib/thread.rs +++ b/crates/vm/src/stdlib/_thread.rs @@ -15,7 +15,7 @@ pub(crate) mod _thread { builtins::{PyDictRef, PyStr, PyTupleRef, PyType, PyTypeRef, PyUtf8StrRef}, common::wtf8::Wtf8Buf, frame::FrameRef, - function::{ArgCallable, Either, FuncArgs, KwArgs, OptionalArg, PySetterValue}, + function::{ArgCallable, FuncArgs, KwArgs, OptionalArg, PySetterValue, TimeoutSeconds}, types::{Constructor, GetAttr, Representable, SetAttr}, }; use alloc::{ @@ -23,11 +23,11 @@ pub(crate) mod _thread { sync::{Arc, Weak}, }; use core::{cell::RefCell, time::Duration}; - use crossbeam_utils::atomic::AtomicCell; use parking_lot::{ RawMutex, RawThreadId, lock_api::{RawMutex as RawMutexT, RawMutexTimed, RawReentrantMutex}, }; + use rustpython_common::str::levenshtein::{MOVE_COST, levenshtein_distance}; use std::thread; // PYTHREAD_NAME: show current thread name @@ -65,36 +65,29 @@ pub(crate) mod _thread { struct AcquireArgs { #[pyarg(any, default = true)] blocking: bool, - #[pyarg(any, default = Either::A(-1.0))] - timeout: Either, + #[pyarg(any, default = TimeoutSeconds::new(-1.0))] + timeout: TimeoutSeconds, } macro_rules! 
acquire_lock_impl { ($mu:expr, $args:expr, $vm:expr) => {{ let (mu, args, vm) = ($mu, $args, $vm); - let timeout = match args.timeout { - Either::A(f) => f, - Either::B(i) => i as f64, - }; + let timeout = args.timeout.to_secs_f64(); match args.blocking { true if timeout == -1.0 => { - mu.lock(); + vm.allow_threads(|| mu.lock()); Ok(true) } true if timeout < 0.0 => { - Err(vm.new_value_error("timeout value must be positive".to_owned())) + Err(vm + .new_value_error("timeout value must be a non-negative number".to_owned())) } true => { - // modified from std::time::Duration::from_secs_f64 to avoid a panic. - // TODO: put this in the Duration::try_from_object impl, maybe? - let nanos = timeout * 1_000_000_000.0; - if timeout > TIMEOUT_MAX as f64 || nanos < 0.0 || !nanos.is_finite() { - return Err(vm.new_overflow_error( - "timestamp too large to convert to Rust Duration".to_owned(), - )); + if timeout > TIMEOUT_MAX { + return Err(vm.new_overflow_error("timeout value is too large".to_owned())); } - Ok(mu.try_lock_for(Duration::from_secs_f64(timeout))) + Ok(vm.allow_threads(|| mu.try_lock_for(Duration::from_secs_f64(timeout)))) } false if timeout != -1.0 => Err(vm .new_value_error("can't specify a timeout for a non-blocking call".to_owned())), @@ -132,7 +125,7 @@ pub(crate) mod _thread { } } - #[pyclass(with(Constructor, Representable))] + #[pyclass(with(Constructor, Representable), flags(HAS_WEAKREF))] impl Lock { #[pymethod] #[pymethod(name = "acquire_lock")] @@ -150,17 +143,12 @@ pub(crate) mod _thread { Ok(()) } + #[cfg(unix)] #[pymethod] fn _at_fork_reinit(&self, _vm: &VirtualMachine) -> PyResult<()> { - // Reset the mutex to unlocked by directly writing the INIT value. - // Do NOT call unlock() here — after fork(), unlock_slow() would - // try to unpark stale waiters from dead parent threads. 
- let new_mut = RawMutex::INIT; - unsafe { - let old_mutex: &AtomicCell = core::mem::transmute(&self.mu); - old_mutex.swap(new_mut); - } - + // Overwrite lock state to unlocked. Do NOT call unlock() here — + // after fork(), unlock_slow() would try to unpark stale waiters. + unsafe { rustpython_common::lock::zero_reinit_after_fork(&self.mu) }; Ok(()) } @@ -205,7 +193,7 @@ pub(crate) mod _thread { } } - #[pyclass(with(Representable), flags(BASETYPE))] + #[pyclass(with(Representable), flags(BASETYPE, HAS_WEAKREF))] impl RLock { #[pyslot] fn slot_new(cls: PyTypeRef, _args: FuncArgs, vm: &VirtualMachine) -> PyResult { @@ -250,18 +238,13 @@ pub(crate) mod _thread { Ok(()) } + #[cfg(unix)] #[pymethod] fn _at_fork_reinit(&self, _vm: &VirtualMachine) -> PyResult<()> { - // Reset the reentrant mutex to unlocked by directly writing INIT. - // Do NOT call unlock() — after fork(), the slow path would try - // to unpark stale waiters from dead parent threads. + // Overwrite lock state to unlocked. Do NOT call unlock() here — + // after fork(), unlock_slow() would try to unpark stale waiters. 
self.count.store(0, core::sync::atomic::Ordering::Relaxed); - let new_mut = RawRMutex::INIT; - unsafe { - let old_mutex: &AtomicCell = core::mem::transmute(&self.mu); - old_mutex.swap(new_mut); - } - + unsafe { rustpython_common::lock::zero_reinit_after_fork(&self.mu) }; Ok(()) } @@ -307,7 +290,7 @@ pub(crate) mod _thread { if count == 0 { return Ok(()); } - self.mu.lock(); + vm.allow_threads(|| self.mu.lock()); self.count .store(count, core::sync::atomic::Ordering::Relaxed); Ok(()) @@ -344,6 +327,63 @@ pub(crate) mod _thread { current_thread_id() } + #[cfg(all(unix, feature = "threading"))] + #[pyfunction] + fn _stop_the_world_stats(vm: &VirtualMachine) -> PyResult { + let stats = vm.state.stop_the_world.stats_snapshot(); + let d = vm.ctx.new_dict(); + d.set_item("stop_calls", vm.ctx.new_int(stats.stop_calls).into(), vm)?; + d.set_item( + "last_wait_ns", + vm.ctx.new_int(stats.last_wait_ns).into(), + vm, + )?; + d.set_item( + "total_wait_ns", + vm.ctx.new_int(stats.total_wait_ns).into(), + vm, + )?; + d.set_item("max_wait_ns", vm.ctx.new_int(stats.max_wait_ns).into(), vm)?; + d.set_item("poll_loops", vm.ctx.new_int(stats.poll_loops).into(), vm)?; + d.set_item( + "attached_seen", + vm.ctx.new_int(stats.attached_seen).into(), + vm, + )?; + d.set_item( + "forced_parks", + vm.ctx.new_int(stats.forced_parks).into(), + vm, + )?; + d.set_item( + "suspend_notifications", + vm.ctx.new_int(stats.suspend_notifications).into(), + vm, + )?; + d.set_item( + "attach_wait_yields", + vm.ctx.new_int(stats.attach_wait_yields).into(), + vm, + )?; + d.set_item( + "suspend_wait_yields", + vm.ctx.new_int(stats.suspend_wait_yields).into(), + vm, + )?; + d.set_item( + "world_stopped", + vm.ctx.new_bool(stats.world_stopped).into(), + vm, + )?; + Ok(d) + } + + #[cfg(all(unix, feature = "threading"))] + #[pyfunction] + fn _stop_the_world_reset_stats(vm: &VirtualMachine) { + vm.state.stop_the_world.reset_stats(); + } + /// Set the name of the current thread #[pyfunction] fn set_name(name: 
PyUtf8StrRef) { @@ -389,7 +429,7 @@ pub(crate) mod _thread { /// This is important for fork compatibility - the ID must remain stable after fork #[cfg(unix)] fn current_thread_id() -> u64 { - // pthread_self() like CPython for fork compatibility + // pthread_self() for fork compatibility unsafe { libc::pthread_self() as u64 } } @@ -441,12 +481,68 @@ pub(crate) mod _thread { } #[pyfunction] - fn start_new_thread( - func: ArgCallable, - args: PyTupleRef, - kwargs: OptionalArg, - vm: &VirtualMachine, - ) -> PyResult { + fn start_new_thread(mut f_args: FuncArgs, vm: &VirtualMachine) -> PyResult { + if !f_args.kwargs.is_empty() { + return Err(vm.new_type_error("start_new_thread() takes no keyword arguments")); + } + let given = f_args.args.len(); + if given < 2 { + return Err(vm.new_type_error(format!( + "start_new_thread expected at least 2 arguments, got {given}" + ))); + } + if given > 3 { + return Err(vm.new_type_error(format!( + "start_new_thread expected at most 3 arguments, got {given}" + ))); + } + + let func_obj = f_args.take_positional().unwrap(); + let args_obj = f_args.take_positional().unwrap(); + let kwargs_obj = f_args.take_positional(); + + if func_obj.to_callable().is_none() { + return Err(vm.new_type_error("first arg must be callable")); + } + if !args_obj.fast_isinstance(vm.ctx.types.tuple_type) { + return Err(vm.new_type_error("2nd arg must be a tuple")); + } + if kwargs_obj + .as_ref() + .is_some_and(|obj| !obj.fast_isinstance(vm.ctx.types.dict_type)) + { + return Err(vm.new_type_error("optional 3rd arg must be a dictionary")); + } + + let func: ArgCallable = func_obj.clone().try_into_value(vm)?; + let args: PyTupleRef = args_obj.clone().try_into_value(vm)?; + let kwargs: Option = kwargs_obj.map(|obj| obj.try_into_value(vm)).transpose()?; + + vm.sys_module.get_attr("audit", vm)?.call( + ( + "_thread.start_new_thread", + func_obj, + args_obj, + kwargs + .as_ref() + .map_or_else(|| vm.ctx.none(), |k| k.clone().into()), + ), + vm, + )?; + + if vm + 
.state + .finalizing + .load(core::sync::atomic::Ordering::Acquire) + { + return Err(vm.new_exception_msg( + vm.ctx.exceptions.python_finalization_error.to_owned(), + "can't create new thread at interpreter shutdown" + .to_owned() + .into(), + )); + } + let args = FuncArgs::new( args.to_vec(), kwargs @@ -466,7 +562,7 @@ pub(crate) mod _thread { .make_spawn_func(move |vm| run_thread(func, args, vm)), ) .map(|handle| thread_to_id(&handle)) - .map_err(|err| vm.new_runtime_error(format!("can't start new thread: {err}"))) + .map_err(|_err| vm.new_runtime_error("can't start new thread")) } fn run_thread(func: ArgCallable, args: FuncArgs, vm: &VirtualMachine) { @@ -584,14 +680,17 @@ pub(crate) mod _thread { }; match handle_to_join { - Some((_, done_event)) => { - // Wait for this thread to finish (infinite timeout) - // Only check done flag to avoid lock ordering issues - // (done_event lock vs inner lock) - let (lock, cvar) = &*done_event; - let mut done = lock.lock(); - while !*done { - cvar.wait(&mut done); + Some((inner, done_event)) => { + if let Err(exc) = ThreadHandle::join_internal(&inner, &done_event, None, vm) { + vm.run_unraisable( + exc, + Some( + "Exception ignored while joining a thread in _thread._shutdown()" + .to_owned(), + ), + vm.ctx.none(), + ); + return; } } None => break, // No more threads to wait on @@ -609,6 +708,24 @@ pub(crate) mod _thread { handles.push((Arc::downgrade(inner), Arc::downgrade(done_event))); } + fn remove_from_shutdown_handles( + vm: &VirtualMachine, + inner: &Arc>, + done_event: &Arc<(parking_lot::Mutex, parking_lot::Condvar)>, + ) { + let mut handles = vm.state.shutdown_handles.lock(); + handles.retain(|(inner_weak, done_event_weak): &ShutdownEntry| { + let Some(registered_inner) = inner_weak.upgrade() else { + return false; + }; + let Some(registered_done_event) = done_event_weak.upgrade() else { + return false; + }; + !(Arc::ptr_eq(®istered_inner, inner) + && Arc::ptr_eq(®istered_done_event, done_event)) + }); + } + 
#[pyfunction] fn _make_thread_handle(ident: u64, vm: &VirtualMachine) -> PyRef { let handle = ThreadHandle::new(vm); @@ -699,9 +816,7 @@ pub(crate) mod _thread { fn _excepthook(args: crate::PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { // Type check: args must be _ExceptHookArgs let args = args.downcast::().map_err(|_| { - vm.new_type_error( - "_thread._excepthook argument type must be _ExceptHookArgs".to_owned(), - ) + vm.new_type_error("_thread._excepthook argument type must be _ExceptHookArgs") })?; let exc_type = args.exc_type.clone(); @@ -1021,10 +1136,7 @@ pub(crate) mod _thread { /// Reset a parking_lot::Mutex to unlocked state after fork. #[cfg(unix)] fn reinit_parking_lot_mutex(mutex: &parking_lot::Mutex) { - unsafe { - let raw = mutex.raw() as *const parking_lot::RawMutex as *mut u8; - core::ptr::write_bytes(raw, 0, core::mem::size_of::()); - } + unsafe { rustpython_common::lock::zero_reinit_after_fork(mutex.raw()) }; } // Thread handle state enum @@ -1067,7 +1179,7 @@ pub(crate) mod _thread { done_event: Arc<(parking_lot::Mutex, parking_lot::Condvar)>, } - #[pyclass] + #[pyclass(with(Representable))] impl ThreadHandle { fn new(vm: &VirtualMachine) -> Self { let inner = Arc::new(parking_lot::Mutex::new(ThreadHandleInner { @@ -1089,110 +1201,254 @@ pub(crate) mod _thread { Self { inner, done_event } } - #[pygetset] - fn ident(&self) -> u64 { - self.inner.lock().ident - } - - #[pymethod] - fn is_done(&self) -> bool { - self.inner.lock().state == ThreadHandleState::Done - } - - #[pymethod] - fn _set_done(&self) { - self.inner.lock().state = ThreadHandleState::Done; - // Signal waiting threads that this thread is done - let (lock, cvar) = &*self.done_event; - *lock.lock() = true; - cvar.notify_all(); - } - - #[pymethod] - fn join( - &self, - timeout: OptionalArg>>, + fn join_internal( + inner: &Arc>, + done_event: &Arc<(parking_lot::Mutex, parking_lot::Condvar)>, + timeout_duration: Option, vm: &VirtualMachine, ) -> PyResult<()> { - // Convert timeout 
to Duration (None or negative = infinite wait) - let timeout_duration = match timeout.flatten() { - Some(Either::A(t)) if t >= 0.0 => Some(Duration::from_secs_f64(t)), - Some(Either::B(t)) if t >= 0 => Some(Duration::from_secs(t as u64)), - _ => None, - }; + Self::check_started(inner, vm)?; - // Check for self-join first - { - let inner = self.inner.lock(); - let current_ident = get_ident(); - if inner.ident == current_ident && inner.state == ThreadHandleState::Running { - return Err(vm.new_runtime_error("cannot join current thread".to_owned())); - } - } + let deadline = + timeout_duration.and_then(|timeout| std::time::Instant::now().checked_add(timeout)); // Wait for thread completion using Condvar (supports timeout) // Loop to handle spurious wakeups - let (lock, cvar) = &*self.done_event; + let (lock, cvar) = &**done_event; let mut done = lock.lock(); + // ThreadHandle_join semantics: self-join/finalizing checks + // apply only while target thread has not reported it is exiting yet. + if !*done { + let inner_guard = inner.lock(); + let current_ident = get_ident(); + if inner_guard.ident == current_ident + && inner_guard.state == ThreadHandleState::Running + { + return Err(vm.new_runtime_error("Cannot join current thread")); + } + if vm + .state + .finalizing + .load(core::sync::atomic::Ordering::Acquire) + { + return Err(vm.new_exception_msg( + vm.ctx.exceptions.python_finalization_error.to_owned(), + "cannot join thread at interpreter shutdown" + .to_owned() + .into(), + )); + } + } + while !*done { if let Some(timeout) = timeout_duration { - let result = cvar.wait_for(&mut done, timeout); + let remaining = deadline.map_or(timeout, |deadline| { + deadline.saturating_duration_since(std::time::Instant::now()) + }); + if remaining.is_zero() { + return Ok(()); + } + let result = vm.allow_threads(|| cvar.wait_for(&mut done, remaining)); if result.timed_out() && !*done { // Timeout occurred and done is still false return Ok(()); } } else { // Infinite wait - 
cvar.wait(&mut done); + vm.allow_threads(|| cvar.wait(&mut done)); } } drop(done); // Thread is done, now perform cleanup let join_handle = { - let mut inner = self.inner.lock(); + let mut inner_guard = inner.lock(); // If already joined, return immediately (idempotent) - if inner.joined { + if inner_guard.joined { return Ok(()); } // If another thread is already joining, wait for them to finish - if inner.joining { - drop(inner); + if inner_guard.joining { + drop(inner_guard); // Wait on done_event - let (lock, cvar) = &*self.done_event; + let (lock, cvar) = &**done_event; let mut done = lock.lock(); while !*done { - cvar.wait(&mut done); + vm.allow_threads(|| cvar.wait(&mut done)); } return Ok(()); } // Mark that we're joining - inner.joining = true; + inner_guard.joining = true; // Take the join handle if available - inner.join_handle.take() + inner_guard.join_handle.take() }; // Perform the actual join outside the lock if let Some(handle) = join_handle { // Ignore the result - panics in spawned threads are already handled - let _ = handle.join(); + let _ = vm.allow_threads(|| handle.join()); } // Mark as joined and clear joining flag { - let mut inner = self.inner.lock(); - inner.joined = true; - inner.joining = false; + let mut inner_guard = inner.lock(); + inner_guard.joined = true; + inner_guard.joining = false; } Ok(()) } + fn check_started( + inner: &Arc>, + vm: &VirtualMachine, + ) -> PyResult<()> { + let state = inner.lock().state; + if matches!( + state, + ThreadHandleState::NotStarted | ThreadHandleState::Starting + ) { + return Err(vm.new_runtime_error("thread not started")); + } + Ok(()) + } + + fn set_done_internal( + inner: &Arc>, + done_event: &Arc<(parking_lot::Mutex, parking_lot::Condvar)>, + vm: &VirtualMachine, + ) -> PyResult<()> { + Self::check_started(inner, vm)?; + { + let mut inner_guard = inner.lock(); + inner_guard.state = ThreadHandleState::Done; + // _set_done() detach path. 
Dropping the JoinHandle + // detaches the underlying Rust thread. + inner_guard.join_handle = None; + inner_guard.joining = false; + inner_guard.joined = true; + } + remove_from_shutdown_handles(vm, inner, done_event); + + let (lock, cvar) = &**done_event; + *lock.lock() = true; + cvar.notify_all(); + Ok(()) + } + + fn parse_join_timeout( + timeout_obj: Option, + vm: &VirtualMachine, + ) -> PyResult> { + const JOIN_TIMEOUT_MAX_SECONDS: i64 = TIMEOUT_MAX_IN_MICROSECONDS / 1_000_000; + let Some(timeout_obj) = timeout_obj else { + return Ok(None); + }; + + if let Some(t) = timeout_obj.try_index_opt(vm) { + let t: i64 = t?.try_to_primitive(vm).map_err(|_| { + vm.new_overflow_error("timestamp too large to convert to C PyTime_t") + })?; + if !(-JOIN_TIMEOUT_MAX_SECONDS..=JOIN_TIMEOUT_MAX_SECONDS).contains(&t) { + return Err( + vm.new_overflow_error("timestamp too large to convert to C PyTime_t") + ); + } + if t < 0 { + return Ok(None); + } + return Ok(Some(Duration::from_secs(t as u64))); + } + + if let Some(t) = timeout_obj.try_float_opt(vm) { + let t = t?.to_f64(); + if t.is_nan() { + return Err(vm.new_value_error("Invalid value NaN (not a number)")); + } + if !t.is_finite() || !(-TIMEOUT_MAX..=TIMEOUT_MAX).contains(&t) { + return Err(vm.new_overflow_error("timestamp out of range for platform time_t")); + } + if t < 0.0 { + return Ok(None); + } + return Ok(Some(Duration::from_secs_f64(t))); + } + + Err(vm.new_type_error(format!( + "'{}' object cannot be interpreted as an integer or float", + timeout_obj.class().name() + ))) + } + + #[pygetset] + fn ident(&self) -> u64 { + self.inner.lock().ident + } + + #[pymethod] + fn is_done(&self, f_args: FuncArgs, vm: &VirtualMachine) -> PyResult { + if !f_args.kwargs.is_empty() { + return Err(vm.new_type_error("_ThreadHandle.is_done() takes no keyword arguments")); + } + let given = f_args.args.len(); + if given != 0 { + return Err(vm.new_type_error(format!( + "_ThreadHandle.is_done() takes no arguments ({given} given)" + ))); + 
} + + // If completion was observed, perform one-time join cleanup + // before returning True. + let done = { + let (lock, _) = &*self.done_event; + *lock.lock() + }; + if !done { + return Ok(false); + } + Self::join_internal(&self.inner, &self.done_event, Some(Duration::ZERO), vm)?; + Ok(true) + } + + #[pymethod] + fn _set_done(&self, f_args: FuncArgs, vm: &VirtualMachine) -> PyResult<()> { + if !f_args.kwargs.is_empty() { + return Err( + vm.new_type_error("_ThreadHandle._set_done() takes no keyword arguments") + ); + } + let given = f_args.args.len(); + if given != 0 { + return Err(vm.new_type_error(format!( + "_ThreadHandle._set_done() takes no arguments ({given} given)" + ))); + } + + Self::set_done_internal(&self.inner, &self.done_event, vm) + } + + #[pymethod] + fn join(&self, mut f_args: FuncArgs, vm: &VirtualMachine) -> PyResult<()> { + if !f_args.kwargs.is_empty() { + return Err(vm.new_type_error("_ThreadHandle.join() takes no keyword arguments")); + } + let given = f_args.args.len(); + if given > 1 { + return Err( + vm.new_type_error(format!("join() takes at most 1 argument ({given} given)")) + ); + } + let timeout = f_args.take_positional().filter(|obj| !vm.is_none(obj)); + let timeout_duration = Self::parse_join_timeout(timeout, vm)?; + Self::join_internal(&self.inner, &self.done_event, timeout_duration, vm) + } + #[pyslot] fn slot_new(cls: PyTypeRef, _args: FuncArgs, vm: &VirtualMachine) -> PyResult { ThreadHandle::new(vm) @@ -1201,38 +1457,174 @@ pub(crate) mod _thread { } } - #[derive(FromArgs)] - struct StartJoinableThreadArgs { - #[pyarg(positional)] - function: ArgCallable, - #[pyarg(any, optional)] - handle: OptionalArg>, - #[pyarg(any, default = true)] - daemon: bool, + impl Representable for ThreadHandle { + fn repr_str(zelf: &Py, _vm: &VirtualMachine) -> PyResult { + let ident = zelf.inner.lock().ident; + Ok(format!( + "<{} object: ident={ident}>", + zelf.class().slot_name() + )) + } } #[pyfunction] fn start_joinable_thread( - args: 
StartJoinableThreadArgs, + mut f_args: FuncArgs, vm: &VirtualMachine, ) -> PyResult> { - let handle = match args.handle { - OptionalArg::Present(h) => h, - OptionalArg::Missing => ThreadHandle::new(vm).into_ref(&vm.ctx), + let given = f_args.args.len() + f_args.kwargs.len(); + if given > 3 { + return Err(vm.new_type_error(format!( + "start_joinable_thread() takes at most 3 arguments ({given} given)" + ))); + } + + let function_pos = f_args.take_positional(); + let function_kw = f_args.take_keyword("function"); + if function_pos.is_some() && function_kw.is_some() { + return Err(vm.new_type_error( + "argument for start_joinable_thread() given by name ('function') and position (1)", + )); + } + let Some(function_obj) = function_pos.or(function_kw) else { + return Err(vm.new_type_error( + "start_joinable_thread() missing required argument 'function' (pos 1)", + )); }; - // Mark as starting - handle.inner.lock().state = ThreadHandleState::Starting; + let handle_pos = f_args.take_positional(); + let handle_kw = f_args.take_keyword("handle"); + if handle_pos.is_some() && handle_kw.is_some() { + return Err(vm.new_type_error( + "argument for start_joinable_thread() given by name ('handle') and position (2)", + )); + } + let handle_obj = handle_pos.or(handle_kw); + + let daemon_pos = f_args.take_positional(); + let daemon_kw = f_args.take_keyword("daemon"); + if daemon_pos.is_some() && daemon_kw.is_some() { + return Err(vm.new_type_error( + "argument for start_joinable_thread() given by name ('daemon') and position (3)", + )); + } + let daemon = daemon_pos + .or(daemon_kw) + .map_or(Ok(true), |obj| obj.try_to_bool(vm))?; + + // Match CPython parser precedence: + // - required positional/keyword argument errors are raised before + // unknown keyword errors when `function` is missing. 
+ if let Some(unexpected) = f_args.kwargs.keys().next() { + let suggestion = ["function", "handle", "daemon"] + .iter() + .filter_map(|candidate| { + let max_distance = (unexpected.len() + candidate.len() + 3) * MOVE_COST / 6; + let distance = levenshtein_distance( + unexpected.as_bytes(), + candidate.as_bytes(), + max_distance, + ); + (distance <= max_distance).then_some((distance, *candidate)) + }) + .min_by_key(|(distance, _)| *distance) + .map(|(_, candidate)| candidate); + let msg = if let Some(suggestion) = suggestion { + format!( + "start_joinable_thread() got an unexpected keyword argument '{unexpected}'. Did you mean '{suggestion}'?" + ) + } else { + format!("start_joinable_thread() got an unexpected keyword argument '{unexpected}'") + }; + return Err(vm.new_type_error(msg)); + } + + if function_obj.to_callable().is_none() { + return Err(vm.new_type_error("thread function must be callable")); + } + let function: ArgCallable = function_obj.clone().try_into_value(vm)?; + + let thread_handle_type = ThreadHandle::class(&vm.ctx); + let handle = if let Some(handle_obj) = handle_obj { + if vm.is_none(&handle_obj) { + None + } else if !handle_obj.class().is(thread_handle_type) { + return Err(vm.new_type_error("'handle' must be a _ThreadHandle")); + } else { + Some( + handle_obj + .downcast::() + .map_err(|_| vm.new_type_error("'handle' must be a _ThreadHandle"))?, + ) + } + } else { + None + }; + + vm.sys_module.get_attr("audit", vm)?.call( + ( + "_thread.start_joinable_thread", + function_obj, + daemon, + handle + .as_ref() + .map_or_else(|| vm.ctx.none(), |h| h.clone().into()), + ), + vm, + )?; + + if vm + .state + .finalizing + .load(core::sync::atomic::Ordering::Acquire) + { + return Err(vm.new_exception_msg( + vm.ctx.exceptions.python_finalization_error.to_owned(), + "can't create new thread at interpreter shutdown" + .to_owned() + .into(), + )); + } + + let handle = match handle { + Some(h) => h, + None => ThreadHandle::new(vm).into_ref(&vm.ctx), + }; + + // 
Must only start once (ThreadHandle_start). + { + let mut inner = handle.inner.lock(); + if inner.state != ThreadHandleState::NotStarted { + return Err(vm.new_runtime_error("thread already started")); + } + inner.state = ThreadHandleState::Starting; + inner.ident = 0; + inner.join_handle = None; + inner.joining = false; + inner.joined = false; + } + // Starting a handle always resets the completion event. + { + let (done_lock, _) = &*handle.done_event; + *done_lock.lock() = false; + } // Add non-daemon threads to shutdown registry so _shutdown() will wait for them - if !args.daemon { + if !daemon { add_to_shutdown_handles(vm, &handle.inner, &handle.done_event); } - let func = args.function; + let func = function; let handle_clone = handle.clone(); let inner_clone = handle.inner.clone(); let done_event_clone = handle.done_event.clone(); + // Use std::sync (pthread-based) instead of parking_lot for these + // events so they remain fork-safe without the parking_lot_core patch. + let started_event = Arc::new((std::sync::Mutex::new(false), std::sync::Condvar::new())); + let started_event_clone = Arc::clone(&started_event); + let handle_ready_event = + Arc::new((std::sync::Mutex::new(false), std::sync::Condvar::new())); + let handle_ready_event_clone = Arc::clone(&handle_ready_event); let mut thread_builder = thread::Builder::new(); let stacksize = vm.state.stacksize.load(); @@ -1242,11 +1634,27 @@ pub(crate) mod _thread { let join_handle = thread_builder .spawn(vm.new_thread().make_spawn_func(move |vm| { - // Set ident and mark as running + // Publish ident for the parent starter thread. + { + inner_clone.lock().ident = get_ident(); + } + { + let (started_lock, started_cvar) = &*started_event_clone; + *started_lock.lock().unwrap() = true; + started_cvar.notify_all(); + } + // Don't execute the target function until parent marks the + // handle as running. 
{ - let mut inner = inner_clone.lock(); - inner.ident = get_ident(); - inner.state = ThreadHandleState::Running; + let (ready_lock, ready_cvar) = &*handle_ready_event_clone; + let mut ready = ready_lock.lock().unwrap(); + while !*ready { + // Short timeout so we stay responsive to STW requests. + let (guard, _) = ready_cvar + .wait_timeout(ready, core::time::Duration::from_millis(1)) + .unwrap(); + ready = guard; + } } // Ensure cleanup happens even if the function panics @@ -1272,6 +1680,9 @@ pub(crate) mod _thread { vm_state.thread_count.fetch_sub(1); + // The runtime no longer needs to wait for this thread. + remove_from_shutdown_handles(vm, &inner_for_cleanup, &done_event_for_cleanup); + // Signal waiting threads that this thread is done // This must be LAST to ensure all cleanup is complete before join() returns { @@ -1297,10 +1708,52 @@ pub(crate) mod _thread { } } })) - .map_err(|err| vm.new_runtime_error(format!("can't start new thread: {err}")))?; + .map_err(|_err| { + // force_done + remove_from_shutdown_handles on start failure. + { + let mut inner = handle.inner.lock(); + inner.state = ThreadHandleState::Done; + inner.join_handle = None; + inner.joining = false; + inner.joined = true; + } + { + let (done_lock, done_cvar) = &*handle.done_event; + *done_lock.lock() = true; + done_cvar.notify_all(); + } + if !daemon { + remove_from_shutdown_handles(vm, &handle.inner, &handle.done_event); + } + vm.new_runtime_error("can't start new thread") + })?; + + // Wait until the new thread has reported its ident. + { + let (started_lock, started_cvar) = &*started_event; + let mut started = started_lock.lock().unwrap(); + while !*started { + let (guard, _) = started_cvar + .wait_timeout(started, core::time::Duration::from_millis(1)) + .unwrap(); + started = guard; + } + } + + // Mark the handle running in the parent thread (like CPython's + // ThreadHandle_start sets THREAD_HANDLE_RUNNING after spawn succeeds). 
+ { + let mut inner = handle.inner.lock(); + inner.join_handle = Some(join_handle); + inner.state = ThreadHandleState::Running; + } - // Store the join handle - handle.inner.lock().join_handle = Some(join_handle); + // Unblock the started thread once handle state is fully published. + { + let (ready_lock, ready_cvar) = &*handle_ready_event; + *ready_lock.lock().unwrap() = true; + ready_cvar.notify_all(); + } Ok(handle_clone) } diff --git a/crates/vm/src/stdlib/typing.rs b/crates/vm/src/stdlib/_typing.rs similarity index 96% rename from crates/vm/src/stdlib/typing.rs rename to crates/vm/src/stdlib/_typing.rs index 79064346e27..7467a7f2574 100644 --- a/crates/vm/src/stdlib/typing.rs +++ b/crates/vm/src/stdlib/_typing.rs @@ -99,7 +99,7 @@ pub(crate) mod decl { type Args = FuncArgs; fn slot_new(_cls: PyTypeRef, _args: FuncArgs, vm: &VirtualMachine) -> PyResult { - Err(vm.new_type_error("cannot create '_typing._ConstEvaluator' instances".to_owned())) + Err(vm.new_type_error("cannot create '_typing._ConstEvaluator' instances")) } fn py_new(_cls: &Py, _args: Self::Args, _vm: &VirtualMachine) -> PyResult { @@ -276,9 +276,7 @@ pub(crate) mod decl { fn __getitem__(zelf: PyRef, args: PyObjectRef, vm: &VirtualMachine) -> PyResult { if zelf.type_params.is_empty() { - return Err( - vm.new_type_error("Only generic type aliases are subscriptable".to_owned()) - ); + return Err(vm.new_type_error("Only generic type aliases are subscriptable")); } let args_tuple = if let Ok(tuple) = args.try_to_ref::(vm) { tuple.to_owned() @@ -373,16 +371,13 @@ pub(crate) mod decl { let name = if !args.args.is_empty() { if args.kwargs.contains_key("name") { return Err(vm.new_type_error( - "argument for typealias() given by name ('name') and position (1)" - .to_owned(), + "argument for typealias() given by name ('name') and position (1)", )); } args.args[0].clone() } else { args.kwargs.get("name").cloned().ok_or_else(|| { - vm.new_type_error( - "typealias() missing required argument 'name' (pos 
1)".to_owned(), - ) + vm.new_type_error("typealias() missing required argument 'name' (pos 1)") })? }; @@ -390,16 +385,13 @@ pub(crate) mod decl { let value = if args.args.len() >= 2 { if args.kwargs.contains_key("value") { return Err(vm.new_type_error( - "argument for typealias() given by name ('value') and position (2)" - .to_owned(), + "argument for typealias() given by name ('value') and position (2)", )); } args.args[1].clone() } else { args.kwargs.get("value").cloned().ok_or_else(|| { - vm.new_type_error( - "typealias() missing required argument 'value' (pos 2)".to_owned(), - ) + vm.new_type_error("typealias() missing required argument 'value' (pos 2)") })? }; @@ -414,7 +406,7 @@ pub(crate) mod decl { let tp = tp .clone() .downcast::() - .map_err(|_| vm.new_type_error("type_params must be a tuple".to_owned()))?; + .map_err(|_| vm.new_type_error("type_params must be a tuple"))?; Self::check_type_params(&tp, vm)?; tp } else { diff --git a/crates/vm/src/stdlib/warnings.rs b/crates/vm/src/stdlib/_warnings.rs similarity index 92% rename from crates/vm/src/stdlib/warnings.rs rename to crates/vm/src/stdlib/_warnings.rs index 9b846725347..a41ce2625c7 100644 --- a/crates/vm/src/stdlib/warnings.rs +++ b/crates/vm/src/stdlib/_warnings.rs @@ -71,7 +71,7 @@ mod _warnings { #[pyfunction] fn _release_lock(vm: &VirtualMachine) -> PyResult<()> { if !vm.state.warnings.release_lock() { - return Err(vm.new_runtime_error("cannot release un-acquired lock".to_owned())); + return Err(vm.new_runtime_error("cannot release un-acquired lock")); } Ok(()) } @@ -140,9 +140,7 @@ mod _warnings { if let Some(ref prefixes) = skip_prefixes { for item in prefixes.iter() { if !item.class().is(vm.ctx.types.str_type) { - return Err( - vm.new_type_error("skip_file_prefixes must be a tuple of strs".to_owned()) - ); + return Err(vm.new_type_error("skip_file_prefixes must be a tuple of strs")); } } } @@ -188,17 +186,17 @@ mod _warnings { && !vm.is_none(mg) && !mg.class().is(vm.ctx.types.dict_type) { - 
return Err(vm.new_type_error("module_globals must be a dict".to_owned())); + return Err(vm.new_type_error("module_globals must be a dict")); } - let category = - if vm.is_none(&args.category) { - None - } else { - Some(PyTypeRef::try_from_object(vm, args.category).map_err(|_| { - vm.new_type_error("category must be a Warning subclass".to_owned()) - })?) - }; + let category = if vm.is_none(&args.category) { + None + } else { + Some( + PyTypeRef::try_from_object(vm, args.category) + .map_err(|_| vm.new_type_error("category must be a Warning subclass"))?, + ) + }; crate::warn::warn_explicit( category, diff --git a/crates/vm/src/stdlib/weakref.rs b/crates/vm/src/stdlib/_weakref.rs similarity index 100% rename from crates/vm/src/stdlib/weakref.rs rename to crates/vm/src/stdlib/_weakref.rs diff --git a/crates/vm/src/stdlib/_winapi.rs b/crates/vm/src/stdlib/_winapi.rs index 1e52af5aaa4..f7d9d0e703f 100644 --- a/crates/vm/src/stdlib/_winapi.rs +++ b/crates/vm/src/stdlib/_winapi.rs @@ -538,7 +538,7 @@ mod _winapi { let ms = if ms < 0 { windows_sys::Win32::System::Threading::INFINITE } else if ms > u32::MAX as i64 { - return Err(vm.new_overflow_error("timeout value is too large".to_owned())); + return Err(vm.new_overflow_error("timeout value is too large")); } else { ms as u32 }; @@ -567,13 +567,11 @@ mod _winapi { .collect(); if handles.is_empty() { - return Err(vm.new_value_error("handle_seq must not be empty".to_owned())); + return Err(vm.new_value_error("handle_seq must not be empty")); } if handles.len() > 64 { - return Err( - vm.new_value_error("WaitForMultipleObjects supports at most 64 handles".to_owned()) - ); + return Err(vm.new_value_error("WaitForMultipleObjects supports at most 64 handles")); } let ret = unsafe { @@ -693,7 +691,7 @@ mod _winapi { let src_wide = src.as_wtf8().to_wide(); if src_wide.len() > i32::MAX as usize { - return Err(vm.new_overflow_error("input string is too long".to_string())); + return Err(vm.new_overflow_error("input string is too 
long")); } // First call to get required buffer size @@ -901,8 +899,7 @@ mod _winapi { let inner = self.inner.lock(); if !inner.completed { return Err(vm.new_value_error( - "can't get read buffer before GetOverlappedResult() signals the operation completed" - .to_owned(), + "can't get read buffer before GetOverlappedResult() signals the operation completed", )); } Ok(inner @@ -1100,7 +1097,7 @@ mod _winapi { let size = size.unwrap_or(0); if size < 0 { - return Err(vm.new_value_error("negative size".to_string())); + return Err(vm.new_value_error("negative size")); } let mut navail: u32 = 0; @@ -1254,6 +1251,26 @@ mod _winapi { err }; + + // Without GIL, the Python-level PipeConnection._send_bytes has a + // race on _send_ov when the caller (SimpleQueue) skips locking on + // Windows. Wait for completion here so the caller never sees + // ERROR_IO_PENDING and never blocks in WaitForMultipleObjects, + // keeping the _send_ov window negligibly small. + if err == ERROR_IO_PENDING { + let event = ov.inner.lock().overlapped.hEvent; + vm.allow_threads(|| unsafe { + windows_sys::Win32::System::Threading::WaitForSingleObject( + event, + windows_sys::Win32::System::Threading::INFINITE, + ); + }); + let result = vm + .ctx + .new_tuple(vec![ov.into_pyobject(vm), vm.ctx.new_int(0u32).into()]); + return Ok(result.into()); + } + let result = vm .ctx .new_tuple(vec![ov.into_pyobject(vm), vm.ctx.new_int(err).into()]); @@ -1520,7 +1537,7 @@ mod _winapi { #[cfg(feature = "threading")] let sigint_event = { - let is_main = crate::stdlib::thread::get_ident() == vm.state.main_thread_ident.load(); + let is_main = crate::stdlib::_thread::get_ident() == vm.state.main_thread_ident.load(); if is_main { let handle = crate::signal::get_sigint_event().unwrap_or_else(|| { let handle = unsafe { WinCreateEventW(null(), 1, 0, null()) }; @@ -1808,9 +1825,9 @@ mod _winapi { if let Some(ref n) = name && n.as_bytes().contains(&0) { - return Err(vm.new_value_error( - "CreateFileMapping: name must not 
contain null characters".to_owned(), - )); + return Err( + vm.new_value_error("CreateFileMapping: name must not contain null characters") + ); } let name_wide = name.as_ref().map(|n| n.as_wtf8().to_wide_with_nul()); let name_ptr = name_wide.as_ref().map_or(null(), |n| n.as_ptr()); @@ -1843,9 +1860,9 @@ mod _winapi { use windows_sys::Win32::System::Memory::OpenFileMappingW; if name.as_bytes().contains(&0) { - return Err(vm.new_value_error( - "OpenFileMapping: name must not contain null characters".to_owned(), - )); + return Err( + vm.new_value_error("OpenFileMapping: name must not contain null characters") + ); } let name_wide = name.as_wtf8().to_wide_with_nul(); let handle = unsafe { diff --git a/crates/vm/src/stdlib/_wmi.rs b/crates/vm/src/stdlib/_wmi.rs index f2b088e96a3..96275e5ac4b 100644 --- a/crates/vm/src/stdlib/_wmi.rs +++ b/crates/vm/src/stdlib/_wmi.rs @@ -567,7 +567,7 @@ mod _wmi { .get(..7) .is_some_and(|s| s.eq_ignore_ascii_case("select ")) { - return Err(vm.new_value_error("only SELECT queries are supported".to_owned())); + return Err(vm.new_value_error("only SELECT queries are supported")); } let query_wide = wide_str(query_str); diff --git a/crates/vm/src/stdlib/atexit.rs b/crates/vm/src/stdlib/atexit.rs index 338fae3b2b7..638927fe90f 100644 --- a/crates/vm/src/stdlib/atexit.rs +++ b/crates/vm/src/stdlib/atexit.rs @@ -7,7 +7,11 @@ mod atexit { #[pyfunction] fn register(func: PyObjectRef, args: FuncArgs, vm: &VirtualMachine) -> PyObjectRef { - vm.state.atexit_funcs.lock().push((func.clone(), args)); + // Callbacks go in LIFO order (insert at front) + vm.state + .atexit_funcs + .lock() + .insert(0, Box::new((func.clone(), args))); func } @@ -18,27 +22,62 @@ mod atexit { #[pyfunction] fn unregister(func: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { - let mut funcs = vm.state.atexit_funcs.lock(); - - let mut i = 0; - while i < funcs.len() { - if vm.bool_eq(&funcs[i].0, &func)? 
{ - funcs.remove(i); - } else { - i += 1; + // Iterate backward (oldest to newest in LIFO list). + // Release the lock during comparison so __eq__ can call atexit functions. + let mut i = { + let funcs = vm.state.atexit_funcs.lock(); + funcs.len() as isize - 1 + }; + while i >= 0 { + let (cb, entry_ptr) = { + let funcs = vm.state.atexit_funcs.lock(); + if i as usize >= funcs.len() { + i = funcs.len() as isize; + i -= 1; + continue; + } + let entry = &funcs[i as usize]; + (entry.0.clone(), &**entry as *const (PyObjectRef, FuncArgs)) + }; + // Lock released: __eq__ can safely call atexit functions + let eq = vm.bool_eq(&func, &cb)?; + if eq { + // The entry may have moved during __eq__. Search backward by identity. + let mut funcs = vm.state.atexit_funcs.lock(); + let mut j = (funcs.len() as isize - 1).min(i); + while j >= 0 { + if core::ptr::eq(&**funcs.get(j as usize).unwrap(), entry_ptr) { + funcs.remove(j as usize); + i = j; + break; + } + j -= 1; + } } + { + let funcs = vm.state.atexit_funcs.lock(); + if i as usize >= funcs.len() { + i = funcs.len() as isize; + } + } + i -= 1; } - Ok(()) } #[pyfunction] pub fn _run_exitfuncs(vm: &VirtualMachine) { let funcs: Vec<_> = core::mem::take(&mut *vm.state.atexit_funcs.lock()); - for (func, args) in funcs.into_iter().rev() { + // Callbacks stored in LIFO order, iterate forward + for entry in funcs.into_iter() { + let (func, args) = *entry; if let Err(e) = func.call(args, vm) { let exit = e.fast_isinstance(vm.ctx.exceptions.system_exit); - vm.run_unraisable(e, Some("Error in atexit._run_exitfuncs".to_owned()), func); + let msg = func + .repr(vm) + .ok() + .map(|r| format!("Exception ignored in atexit callback {}", r.as_wtf8())); + vm.run_unraisable(e, msg, vm.ctx.none()); if exit { break; } diff --git a/crates/vm/src/stdlib/builtins.rs b/crates/vm/src/stdlib/builtins.rs index 528bd4a50a3..c1f7943e88b 100644 --- a/crates/vm/src/stdlib/builtins.rs +++ b/crates/vm/src/stdlib/builtins.rs @@ -181,7 +181,7 @@ mod builtins { /// 
Decode source bytes to a string, handling PEP 263 encoding declarations /// and BOM. Raises SyntaxError for invalid UTF-8 without an encoding - /// declaration (matching CPython behavior). + /// declaration. /// Check if an encoding name is a UTF-8 variant after normalization. /// Matches: utf-8, utf_8, utf8, UTF-8, etc. #[cfg(feature = "parser")] @@ -260,7 +260,7 @@ mod builtins { } #[cfg(feature = "ast")] { - use crate::{class::PyClassImpl, stdlib::ast}; + use crate::{class::PyClassImpl, stdlib::_ast}; let feature_version = feature_version_from_arg(args._feature_version, vm)?; @@ -277,25 +277,24 @@ mod builtins { if args .source - .fast_isinstance(&ast::NodeAst::make_static_type()) + .fast_isinstance(&_ast::NodeAst::make_static_type()) { let flags: i32 = args.flags.map_or(Ok(0), |v| v.try_to_primitive(vm))?; - let is_ast_only = !(flags & ast::PY_CF_ONLY_AST).is_zero(); + let is_ast_only = !(flags & _ast::PY_CF_ONLY_AST).is_zero(); // func_type mode requires PyCF_ONLY_AST if mode_str == "func_type" && !is_ast_only { return Err(vm.new_value_error( - "compile() mode 'func_type' requires flag PyCF_ONLY_AST".to_owned(), + "compile() mode 'func_type' requires flag PyCF_ONLY_AST", )); } // compile(ast_node, ..., PyCF_ONLY_AST) returns the AST after validation if is_ast_only { - let (expected_type, expected_name) = ast::mode_type_and_name(mode_str) + let (expected_type, expected_name) = _ast::mode_type_and_name(mode_str) .ok_or_else(|| { vm.new_value_error( - "compile() mode must be 'exec', 'eval', 'single' or 'func_type'" - .to_owned(), + "compile() mode must be 'exec', 'eval', 'single' or 'func_type'", ) })?; if !args.source.fast_isinstance(&expected_type) { @@ -305,7 +304,7 @@ mod builtins { args.source.class().name() ))); } - ast::validate_ast_object(vm, args.source.clone())?; + _ast::validate_ast_object(vm, args.source.clone())?; return Ok(args.source); } @@ -318,7 +317,7 @@ mod builtins { let mode = mode_str .parse::() .map_err(|err| 
vm.new_value_error(err.to_string()))?; - return ast::compile( + return _ast::compile( vm, args.source, &args.filename.to_string_lossy(), @@ -347,16 +346,16 @@ mod builtins { let flags = args.flags.map_or(Ok(0), |v| v.try_to_primitive(vm))?; - if !(flags & !ast::PY_COMPILE_FLAGS_MASK).is_zero() { + if !(flags & !_ast::PY_COMPILE_FLAGS_MASK).is_zero() { return Err(vm.new_value_error("compile() unrecognized flags")); } - let allow_incomplete = !(flags & ast::PY_CF_ALLOW_INCOMPLETE_INPUT).is_zero(); - let type_comments = !(flags & ast::PY_CF_TYPE_COMMENTS).is_zero(); + let allow_incomplete = !(flags & _ast::PY_CF_ALLOW_INCOMPLETE_INPUT).is_zero(); + let type_comments = !(flags & _ast::PY_CF_TYPE_COMMENTS).is_zero(); let optimize_level = optimize; - if (flags & ast::PY_CF_ONLY_AST).is_zero() { + if (flags & _ast::PY_CF_ONLY_AST).is_zero() { #[cfg(not(feature = "compiler"))] { Err(vm.new_value_error(CODEGEN_NOT_SUPPORTED.to_owned())) @@ -367,7 +366,7 @@ mod builtins { let mode = mode_str .parse::() .map_err(|err| vm.new_value_error(err.to_string()))?; - let _ = ast::parse( + let _ = _ast::parse( vm, source, mode, @@ -399,14 +398,14 @@ mod builtins { } } else { if mode_str == "func_type" { - return ast::parse_func_type(vm, source, optimize_level, feature_version) + return _ast::parse_func_type(vm, source, optimize_level, feature_version) .map_err(|e| (e, Some(source), allow_incomplete).to_pyexception(vm)); } let mode = mode_str .parse::() .map_err(|err| vm.new_value_error(err.to_string()))?; - let parsed = ast::parse( + let parsed = _ast::parse( vm, source, mode, @@ -417,7 +416,7 @@ mod builtins { .map_err(|e| (e, Some(source), allow_incomplete).to_pyexception(vm))?; if mode_str == "single" { - return ast::wrap_interactive(vm, parsed); + return _ast::wrap_interactive(vm, parsed); } Ok(parsed) @@ -500,7 +499,7 @@ mod builtins { "exec() globals must be a dict, not {}", globals.class().name() )), - _ => vm.new_type_error("globals must be a dict".to_owned()), + _ => 
vm.new_type_error("globals must be a dict"), }); } Ok(()) @@ -999,9 +998,7 @@ mod builtins { }; let write = |obj: PyStrRef| vm.call_method(&file, "write", (obj,)); - let sep = options - .sep - .unwrap_or_else(|| PyStr::from(" ").into_ref(&vm.ctx)); + let sep = options.sep.unwrap_or_else(|| vm.ctx.new_str(" ")); let mut first = true; for object in objects { @@ -1014,9 +1011,7 @@ mod builtins { write(object.str(vm)?)?; } - let end = options - .end - .unwrap_or_else(|| PyStr::from("\n").into_ref(&vm.ctx)); + let end = options.end.unwrap_or_else(|| vm.ctx.new_str("\n")); write(end)?; if options.flush.into() { diff --git a/crates/vm/src/stdlib/gc.rs b/crates/vm/src/stdlib/gc.rs index 82b0c68bd9e..3909186b5c0 100644 --- a/crates/vm/src/stdlib/gc.rs +++ b/crates/vm/src/stdlib/gc.rs @@ -51,15 +51,15 @@ mod gc { let generation = args.generation; let generation_num = generation.unwrap_or(2); if !(0..=2).contains(&generation_num) { - return Err(vm.new_value_error("invalid generation".to_owned())); + return Err(vm.new_value_error("invalid generation")); } // Invoke callbacks with "start" phase - invoke_callbacks(vm, "start", generation_num as usize, 0, 0); + invoke_callbacks(vm, "start", generation_num as usize, &Default::default()); // Manual gc.collect() should run even if GC is disabled let gc = gc_state::gc_state(); - let (collected, uncollectable) = gc.collect_force(generation_num as usize); + let result = gc.collect_force(generation_num as usize); // Move objects from gc_state.garbage to vm.ctx.gc_garbage (for DEBUG_SAVEALL) { @@ -74,15 +74,9 @@ mod gc { } // Invoke callbacks with "stop" phase - invoke_callbacks( - vm, - "stop", - generation_num as usize, - collected, - uncollectable, - ); + invoke_callbacks(vm, "stop", generation_num as usize, &result); - Ok(collected as i32) + Ok((result.collected + result.uncollectable) as i32) } /// Return the current collection thresholds as a tuple. 
@@ -148,6 +142,8 @@ mod gc { vm.ctx.new_int(stat.uncollectable).into(), vm, )?; + dict.set_item("candidates", vm.ctx.new_int(stat.candidates).into(), vm)?; + dict.set_item("duration", vm.ctx.new_float(stat.duration).into(), vm)?; result.push(dict.into()); } @@ -189,10 +185,49 @@ mod gc { /// Return the list of objects that directly refer to any of the arguments. #[pyfunction] fn get_referrers(args: FuncArgs, vm: &VirtualMachine) -> PyListRef { - // This is expensive: we need to scan all tracked objects - // For now, return an empty list (would need full object tracking to implement) - let _ = args; - vm.ctx.new_list(vec![]) + use std::collections::HashSet; + + // Build a set of target object pointers for fast lookup + let targets: HashSet = args + .args + .iter() + .map(|obj| obj.as_ref() as *const crate::PyObject as usize) + .collect(); + + // Collect pointers of frames currently on the execution stack. + // In CPython, executing frames (_PyInterpreterFrame) are not GC-tracked + // PyObjects, so they never appear in get_referrers results. Since + // RustPython materializes every frame as a PyObject, we must exclude + // them manually to match the expected behavior. + let stack_frames: HashSet = vm + .frames + .borrow() + .iter() + .map(|fp| { + let frame: &crate::PyObject = unsafe { fp.as_ref() }.as_ref(); + frame as *const crate::PyObject as usize + }) + .collect(); + + let mut result = Vec::new(); + + // Scan all tracked objects across all generations + let all_objects = gc_state::gc_state().get_objects(None); + for obj in all_objects { + let obj_ptr = obj.as_ref() as *const crate::PyObject as usize; + if stack_frames.contains(&obj_ptr) { + continue; + } + let referent_ptrs = unsafe { obj.gc_get_referent_ptrs() }; + for child_ptr in referent_ptrs { + if targets.contains(&(child_ptr.as_ptr() as usize)) { + result.push(obj.clone()); + break; + } + } + } + + vm.ctx.new_list(result) } /// Return True if the object is tracked by the garbage collector. 
@@ -243,8 +278,7 @@ mod gc { vm: &VirtualMachine, phase: &str, generation: usize, - collected: usize, - uncollectable: usize, + result: &gc_state::CollectResult, ) { let callbacks_list = &vm.ctx.gc_callbacks; let callbacks: Vec = callbacks_list.borrow_vec().to_vec(); @@ -255,8 +289,14 @@ mod gc { let phase_str: PyObjectRef = vm.ctx.new_str(phase).into(); let info = vm.ctx.new_dict(); let _ = info.set_item("generation", vm.ctx.new_int(generation).into(), vm); - let _ = info.set_item("collected", vm.ctx.new_int(collected).into(), vm); - let _ = info.set_item("uncollectable", vm.ctx.new_int(uncollectable).into(), vm); + let _ = info.set_item("collected", vm.ctx.new_int(result.collected).into(), vm); + let _ = info.set_item( + "uncollectable", + vm.ctx.new_int(result.uncollectable).into(), + vm, + ); + let _ = info.set_item("candidates", vm.ctx.new_int(result.candidates).into(), vm); + let _ = info.set_item("duration", vm.ctx.new_float(result.duration).into(), vm); for callback in callbacks { let _ = callback.call((phase_str.clone(), info.clone()), vm); diff --git a/crates/vm/src/stdlib/itertools.rs b/crates/vm/src/stdlib/itertools.rs index d1af433d7dc..763c3ddce76 100644 --- a/crates/vm/src/stdlib/itertools.rs +++ b/crates/vm/src/stdlib/itertools.rs @@ -667,7 +667,7 @@ mod decl { groupby: PyRef, } - #[pyclass(with(IterNext, Iterable))] + #[pyclass(with(IterNext, Iterable), flags(HAS_WEAKREF))] impl PyItertoolsGrouper {} impl SelfIter for PyItertoolsGrouper {} diff --git a/crates/vm/src/stdlib/mod.rs b/crates/vm/src/stdlib/mod.rs index f4f266bf161..42514c46dda 100644 --- a/crates/vm/src/stdlib/mod.rs +++ b/crates/vm/src/stdlib/mod.rs @@ -1,35 +1,31 @@ mod _abc; -mod _types; #[cfg(feature = "ast")] -pub(crate) mod ast; +pub(crate) mod _ast; +mod _codecs; +mod _collections; +mod _functools; +mod _imp; +pub mod _io; +mod _operator; +mod _sre; +mod _stat; +mod _string; +#[cfg(feature = "compiler")] +mod _symtable; +mod _sysconfig; +mod _sysconfigdata; +mod _types; +pub mod 
_typing; +pub mod _warnings; +mod _weakref; pub mod atexit; pub mod builtins; -mod codecs; -mod collections; pub mod errno; -mod functools; mod gc; -mod imp; -pub mod io; mod itertools; mod marshal; -mod operator; -// TODO: maybe make this an extension module, if we ever get those -// mod re; -mod sre; -mod stat; -mod string; -#[cfg(feature = "compiler")] -mod symtable; -mod sysconfig; -mod sysconfigdata; -#[cfg(feature = "threading")] -pub mod thread; pub mod time; mod typevar; -pub mod typing; -pub mod warnings; -mod weakref; #[cfg(feature = "host_env")] #[macro_use] @@ -47,7 +43,7 @@ pub mod posix; any(target_os = "linux", target_os = "macos", target_os = "windows"), not(any(target_env = "musl", target_env = "sgx")) ))] -mod ctypes; +mod _ctypes; #[cfg(all(feature = "host_env", windows))] pub(crate) mod msvcrt; @@ -58,10 +54,12 @@ pub(crate) mod msvcrt; ))] mod pwd; +#[cfg(feature = "host_env")] +pub(crate) mod _signal; +#[cfg(feature = "threading")] +pub mod _thread; #[cfg(all(feature = "host_env", windows))] mod _wmi; -#[cfg(feature = "host_env")] -pub(crate) mod signal; pub mod sys; #[cfg(all(feature = "host_env", windows))] #[path = "_winapi.rs"] @@ -83,28 +81,28 @@ pub fn builtin_module_defs(ctx: &Context) -> Vec<&'static PyModuleDef> { _abc::module_def(ctx), _types::module_def(ctx), #[cfg(feature = "ast")] - ast::module_def(ctx), + _ast::module_def(ctx), atexit::module_def(ctx), - codecs::module_def(ctx), - collections::module_def(ctx), + _codecs::module_def(ctx), + _collections::module_def(ctx), #[cfg(all( feature = "host_env", any(target_os = "linux", target_os = "macos", target_os = "windows"), not(any(target_env = "musl", target_env = "sgx")) ))] - ctypes::module_def(ctx), + _ctypes::module_def(ctx), errno::module_def(ctx), - functools::module_def(ctx), + _functools::module_def(ctx), gc::module_def(ctx), - imp::module_def(ctx), - io::module_def(ctx), + _imp::module_def(ctx), + _io::module_def(ctx), itertools::module_def(ctx), marshal::module_def(ctx), 
#[cfg(all(feature = "host_env", windows))] msvcrt::module_def(ctx), #[cfg(all(feature = "host_env", windows))] nt::module_def(ctx), - operator::module_def(ctx), + _operator::module_def(ctx), #[cfg(all(feature = "host_env", any(unix, target_os = "wasi")))] posix::module_def(ctx), #[cfg(all(feature = "host_env", not(any(unix, windows, target_os = "wasi"))))] @@ -116,20 +114,20 @@ pub fn builtin_module_defs(ctx: &Context) -> Vec<&'static PyModuleDef> { ))] pwd::module_def(ctx), #[cfg(feature = "host_env")] - signal::module_def(ctx), - sre::module_def(ctx), - stat::module_def(ctx), - string::module_def(ctx), + _signal::module_def(ctx), + _sre::module_def(ctx), + _stat::module_def(ctx), + _string::module_def(ctx), #[cfg(feature = "compiler")] - symtable::module_def(ctx), - sysconfigdata::module_def(ctx), - sysconfig::module_def(ctx), + _symtable::module_def(ctx), + _sysconfigdata::module_def(ctx), + _sysconfig::module_def(ctx), #[cfg(feature = "threading")] - thread::module_def(ctx), + _thread::module_def(ctx), time::module_def(ctx), - typing::module_def(ctx), - warnings::module_def(ctx), - weakref::module_def(ctx), + _typing::module_def(ctx), + _warnings::module_def(ctx), + _weakref::module_def(ctx), #[cfg(all(feature = "host_env", windows))] winapi::module_def(ctx), #[cfg(all(feature = "host_env", windows))] diff --git a/crates/vm/src/stdlib/nt.rs b/crates/vm/src/stdlib/nt.rs index f9fad0a2033..5dd4cf4f001 100644 --- a/crates/vm/src/stdlib/nt.rs +++ b/crates/vm/src/stdlib/nt.rs @@ -364,9 +364,9 @@ pub(crate) mod module { // If path is a file descriptor, use fchmod if let OsPathOrFd::Fd(fd) = path { if follow_symlinks.into_option().is_some() { - return Err(vm.new_value_error( - "chmod: follow_symlinks is not supported with fd argument".to_owned(), - )); + return Err( + vm.new_value_error("chmod: follow_symlinks is not supported with fd argument") + ); } return fchmod_impl(fd.as_raw(), mode, vm); } @@ -1377,7 +1377,7 @@ pub(crate) mod module { let wide = 
path.to_wide_cstring(vm)?; let buflen = core::cmp::max(wide.len(), Foundation::MAX_PATH as usize); if buflen > u32::MAX as usize { - return Err(vm.new_overflow_error("path too long".to_owned())); + return Err(vm.new_overflow_error("path too long")); } let mut buffer = vec![0u16; buflen]; let ret = unsafe { @@ -2251,7 +2251,7 @@ pub(crate) mod module { // PathBuffer starts at offset 16 (sub_offset, sub_length, 16usize) } else { - return Err(vm.new_value_error("not a symbolic link".to_owned())); + return Err(vm.new_value_error("not a symbolic link")); }; // Extract the substitute name diff --git a/crates/vm/src/stdlib/os.rs b/crates/vm/src/stdlib/os.rs index 0736ca18032..03c5e33de76 100644 --- a/crates/vm/src/stdlib/os.rs +++ b/crates/vm/src/stdlib/os.rs @@ -111,7 +111,7 @@ pub(crate) fn warn_if_bool_fd(obj: &PyObjectRef, vm: &VirtualMachine) -> PyResul .class() .is(crate::builtins::bool_::PyBool::static_type()) { - crate::stdlib::warnings::warn( + crate::stdlib::_warnings::warn( vm.ctx.exceptions.runtime_warning, "bool is used as a file descriptor".to_owned(), 1, @@ -287,7 +287,7 @@ pub(super) mod _os { fn read(fd: crt_fd::Borrowed<'_>, n: usize, vm: &VirtualMachine) -> PyResult { let mut buffer = vec![0u8; n]; loop { - match crt_fd::read(fd, &mut buffer) { + match vm.allow_threads(|| crt_fd::read(fd, &mut buffer)) { Ok(n) => { buffer.truncate(n); return Ok(vm.ctx.new_bytes(buffer)); @@ -309,7 +309,7 @@ pub(super) mod _os { ) -> PyResult { buffer.with_ref(|buf| { loop { - match crt_fd::read(fd, buf) { + match vm.allow_threads(|| crt_fd::read(fd, buf)) { Ok(n) => return Ok(n), Err(e) if e.raw_os_error() == Some(libc::EINTR) => { vm.check_signals()?; @@ -322,8 +322,12 @@ pub(super) mod _os { } #[pyfunction] - fn write(fd: crt_fd::Borrowed<'_>, data: ArgBytesLike) -> io::Result { - data.with_ref(|b| crt_fd::write(fd, b)) + fn write( + fd: crt_fd::Borrowed<'_>, + data: ArgBytesLike, + vm: &VirtualMachine, + ) -> io::Result { + data.with_ref(|b| vm.allow_threads(|| 
crt_fd::write(fd, b))) } #[cfg(not(windows))] @@ -864,7 +868,7 @@ pub(super) mod _os { #[pymethod] fn __reduce__(&self, vm: &VirtualMachine) -> PyResult { - Err(vm.new_type_error("cannot pickle 'DirEntry' object".to_owned())) + Err(vm.new_type_error("cannot pickle 'DirEntry' object")) } } @@ -927,14 +931,14 @@ pub(super) mod _os { #[pymethod] fn __reduce__(&self, vm: &VirtualMachine) -> PyResult { - Err(vm.new_type_error("cannot pickle 'ScandirIterator' object".to_owned())) + Err(vm.new_type_error("cannot pickle 'ScandirIterator' object")) } } impl Destructor for ScandirIterator { fn del(zelf: &Py, vm: &VirtualMachine) -> PyResult<()> { // Emit ResourceWarning if the iterator is not yet exhausted/closed if zelf.entries.read().is_some() { - let _ = crate::stdlib::warnings::warn( + let _ = crate::stdlib::_warnings::warn( vm.ctx.exceptions.resource_warning, format!("unclosed scandir iterator {:?}", zelf.as_object()), 1, @@ -1089,7 +1093,7 @@ pub(super) mod _os { #[pymethod] fn __reduce__(&self, vm: &VirtualMachine) -> PyResult { - Err(vm.new_type_error("cannot pickle 'ScandirIterator' object".to_owned())) + Err(vm.new_type_error("cannot pickle 'ScandirIterator' object")) } } @@ -1097,7 +1101,7 @@ pub(super) mod _os { impl Destructor for ScandirIteratorFd { fn del(zelf: &Py, vm: &VirtualMachine) -> PyResult<()> { if zelf.dir.lock().is_some() { - let _ = crate::stdlib::warnings::warn( + let _ = crate::stdlib::_warnings::warn( vm.ctx.exceptions.resource_warning, format!("unclosed scandir iterator {:?}", zelf.as_object()), 1, diff --git a/crates/vm/src/stdlib/posix.rs b/crates/vm/src/stdlib/posix.rs index 4cdb12f0d47..8cde18a47ba 100644 --- a/crates/vm/src/stdlib/posix.rs +++ b/crates/vm/src/stdlib/posix.rs @@ -767,9 +767,18 @@ pub mod module { // only for before_forkers, refer: test_register_at_fork in test_posix run_at_forkers(before_forkers, true, vm); + + #[cfg(feature = "threading")] + crate::stdlib::_imp::acquire_imp_lock_for_fork(); + + #[cfg(feature = 
"threading")] + vm.state.stop_the_world.stop_the_world(vm); } fn py_os_after_fork_child(vm: &VirtualMachine) { + #[cfg(feature = "threading")] + vm.state.stop_the_world.reset_after_fork(); + // Phase 1: Reset all internal locks FIRST. // After fork(), locks held by dead parent threads would deadlock // if we try to acquire them. This must happen before anything else. @@ -781,12 +790,12 @@ pub mod module { // held by dead parent threads, causing deadlocks on any IO in the child. #[cfg(feature = "threading")] unsafe { - crate::stdlib::io::reinit_std_streams_after_fork(vm) + crate::stdlib::_io::reinit_std_streams_after_fork(vm) }; // Phase 2: Reset low-level atomic state (no locks needed). crate::signal::clear_after_fork(); - crate::stdlib::signal::_signal::clear_wakeup_fd_after_fork(); + crate::stdlib::_signal::_signal::clear_wakeup_fd_after_fork(); // Reset weakref stripe locks that may have been held during fork. #[cfg(feature = "threading")] @@ -795,7 +804,14 @@ pub mod module { // Phase 3: Clean up thread state. Locks are now reinit'd so we can // acquire them normally instead of using try_lock(). #[cfg(feature = "threading")] - crate::stdlib::thread::after_fork_child(vm); + crate::stdlib::_thread::after_fork_child(vm); + + // CPython parity: reinit import lock ownership metadata in child + // and release the lock acquired by PyOS_BeforeFork(). + #[cfg(feature = "threading")] + unsafe { + crate::stdlib::_imp::after_fork_child_imp_lock_release() + }; // Initialize signal handlers for the child's main thread. // When forked from a worker thread, the OnceCell is empty. 
@@ -841,84 +857,186 @@ pub mod module { crate::gc_state::gc_state().reinit_after_fork(); // Import lock (RawReentrantMutex) - crate::stdlib::imp::reinit_imp_lock_after_fork(); + crate::stdlib::_imp::reinit_imp_lock_after_fork(); } } fn py_os_after_fork_parent(vm: &VirtualMachine) { + #[cfg(feature = "threading")] + vm.state.stop_the_world.start_the_world(vm); + + #[cfg(feature = "threading")] + crate::stdlib::_imp::release_imp_lock_after_fork_parent(); + let after_forkers_parent: Vec = vm.state.after_forkers_parent.lock().clone(); run_at_forkers(after_forkers_parent, false, vm); } - /// Warn if forking from a multi-threaded process - fn warn_if_multi_threaded(name: &str, vm: &VirtualMachine) { - // Only check threading if it was already imported - // Avoid vm.import() which can execute arbitrary Python code in the fork path - let threading = match vm - .sys_module - .get_attr("modules", vm) - .and_then(|m| m.get_item("threading", vm)) + /// Best-effort number of OS threads in this process. + /// Returns <= 0 when unavailable. 
+ fn get_number_of_os_threads() -> isize { + #[cfg(target_os = "macos")] { - Ok(m) => m, - Err(_) => return, - }; - let active = threading.get_attr("_active", vm).ok(); - let limbo = threading.get_attr("_limbo", vm).ok(); + type MachPortT = libc::c_uint; + type KernReturnT = libc::c_int; + type MachMsgTypeNumberT = libc::c_uint; + type ThreadActArrayT = *mut MachPortT; + const KERN_SUCCESS: KernReturnT = 0; + unsafe extern "C" { + fn mach_task_self() -> MachPortT; + fn task_for_pid( + task: MachPortT, + pid: libc::c_int, + target_task: *mut MachPortT, + ) -> KernReturnT; + fn task_threads( + target_task: MachPortT, + act_list: *mut ThreadActArrayT, + act_list_cnt: *mut MachMsgTypeNumberT, + ) -> KernReturnT; + fn vm_deallocate( + target_task: MachPortT, + address: libc::uintptr_t, + size: libc::uintptr_t, + ) -> KernReturnT; + } - let count_dict = |obj: Option| -> usize { - obj.and_then(|o| o.length_opt(vm)) - .and_then(|r| r.ok()) - .unwrap_or(0) - }; + let self_task = unsafe { mach_task_self() }; + let mut proc_task: MachPortT = 0; + if unsafe { task_for_pid(self_task, libc::getpid(), &mut proc_task) } == KERN_SUCCESS { + let mut threads: ThreadActArrayT = core::ptr::null_mut(); + let mut n_threads: MachMsgTypeNumberT = 0; + if unsafe { task_threads(proc_task, &mut threads, &mut n_threads) } == KERN_SUCCESS + { + if !threads.is_null() { + let _ = unsafe { + vm_deallocate( + self_task, + threads as libc::uintptr_t, + (n_threads as usize * core::mem::size_of::()) + as libc::uintptr_t, + ) + }; + } + return n_threads as isize; + } + } + 0 + } + #[cfg(target_os = "linux")] + { + use std::io::Read as _; + let mut file = match std::fs::File::open("/proc/self/stat") { + Ok(f) => f, + Err(_) => return 0, + }; + let mut buf = [0u8; 160]; + let n = match file.read(&mut buf) { + Ok(n) => n, + Err(_) => return 0, + }; + let line = match core::str::from_utf8(&buf[..n]) { + Ok(s) => s, + Err(_) => return 0, + }; + if let Some(field) = line.split_whitespace().nth(19) { + return 
field.parse::().unwrap_or(0); + } + 0 + } + #[cfg(not(any(target_os = "macos", target_os = "linux")))] + { + 0 + } + } - let num_threads = count_dict(active) + count_dict(limbo); - if num_threads > 1 { - // Use Python warnings module to ensure filters are applied correctly - let Ok(warnings) = vm.import("warnings", 0) else { - return; + /// Warn if forking from a multi-threaded process. + /// `num_os_threads` should be captured before parent after-fork hooks run. + fn warn_if_multi_threaded(name: &str, num_os_threads: isize, vm: &VirtualMachine) { + let num_threads = if num_os_threads > 0 { + num_os_threads as usize + } else { + // CPython fallback: if OS-level count isn't available, use the + // threading module's active+limbo view. + // Only check threading if it was already imported. Avoid vm.import() + // which can execute arbitrary Python code in the fork path. + let threading = match vm + .sys_module + .get_attr("modules", vm) + .and_then(|m| m.get_item("threading", vm)) + { + Ok(m) => m, + Err(_) => return, }; - let Ok(warn_fn) = warnings.get_attr("warn", vm) else { - return; + let active = threading.get_attr("_active", vm).ok(); + let limbo = threading.get_attr("_limbo", vm).ok(); + + // Match threading module internals and avoid sequence overcounting: + // count only dict-backed _active/_limbo containers. 
+ let count_dict = |obj: Option| -> usize { + obj.and_then(|o| { + o.downcast_ref::() + .map(|d| d.__len__()) + }) + .unwrap_or(0) }; + count_dict(active) + count_dict(limbo) + }; + + if num_threads > 1 { let pid = unsafe { libc::getpid() }; let msg = format!( "This process (pid={}) is multi-threaded, use of {}() may lead to deadlocks in the child.", pid, name ); - // Call warnings.warn(message, DeprecationWarning, stacklevel=2) - // stacklevel=2 to point to the caller of fork() - let args = crate::function::FuncArgs::new( - vec![ - vm.ctx.new_str(msg).into(), - vm.ctx.exceptions.deprecation_warning.as_object().to_owned(), - ], - crate::function::KwArgs::new( - [("stacklevel".to_owned(), vm.ctx.new_int(2).into())] - .into_iter() - .collect(), - ), - ); - let _ = warn_fn.call(args, vm); + // Match PyErr_WarnFormat(..., stacklevel=1) in CPython. + // Best effort: ignore failures like CPython does in this path. + let _ = + crate::stdlib::_warnings::warn(vm.ctx.exceptions.deprecation_warning, msg, 1, vm); } } #[pyfunction] - fn fork(vm: &VirtualMachine) -> i32 { - warn_if_multi_threaded("fork", vm); + fn fork(vm: &VirtualMachine) -> PyResult { + if vm + .state + .finalizing + .load(core::sync::atomic::Ordering::Acquire) + { + return Err(vm.new_exception_msg( + vm.ctx.exceptions.python_finalization_error.to_owned(), + "can't fork at interpreter shutdown".into(), + )); + } + + // RustPython does not yet have C-level audit hooks; call sys.audit() + // to preserve Python-visible behavior and failure semantics. + vm.sys_module + .get_attr("audit", vm)? + .call(("os.fork",), vm)?; - let pid: i32; py_os_before_fork(vm); - unsafe { - pid = libc::fork(); - } + + let pid = unsafe { libc::fork() }; + // Save errno immediately — AfterFork callbacks may clobber it. + let saved_errno = nix::Error::last_raw(); if pid == 0 { py_os_after_fork_child(vm); } else { + // Match CPython timing: capture this before parent after-fork hooks + // in case those hooks start threads. 
+ let num_os_threads = get_number_of_os_threads(); py_os_after_fork_parent(vm); + // Match CPython timing: warn only after parent callback path resumes world. + warn_if_multi_threaded("fork", num_os_threads, vm); + } + if pid == -1 { + Err(nix::Error::from_raw(saved_errno).into_pyexception(vm)) + } else { + Ok(pid) } - pid } #[cfg(not(target_os = "redox"))] @@ -1835,13 +1953,18 @@ pub mod module { fn waitpid(pid: libc::pid_t, opt: i32, vm: &VirtualMachine) -> PyResult<(libc::pid_t, i32)> { let mut status = 0; loop { - let res = unsafe { libc::waitpid(pid, &mut status, opt) }; + // Capture errno inside the closure: attach_thread (called by + // allow_threads on return) can clobber errno via syscalls. + let (res, err) = vm.allow_threads(|| { + let r = unsafe { libc::waitpid(pid, &mut status, opt) }; + (r, nix::Error::last_raw()) + }); if res == -1 { - if nix::Error::last_raw() == libc::EINTR { + if err == libc::EINTR { vm.check_signals()?; continue; } - return Err(nix::Error::last().into_pyexception(vm)); + return Err(nix::Error::from_raw(err).into_pyexception(vm)); } return Ok((res, status)); } @@ -2061,9 +2184,7 @@ pub mod module { Ok(int) => int.try_to_primitive(vm)?, Err(obj) => { let s = obj.downcast::().map_err(|_| { - vm.new_type_error( - "configuration names must be strings or integers".to_owned(), - ) + vm.new_type_error("configuration names must be strings or integers") })?; s.as_str() .parse::() @@ -2456,9 +2577,7 @@ pub mod module { Ok(int) => int.try_to_primitive(vm)?, Err(obj) => { let s = obj.downcast::().map_err(|_| { - vm.new_type_error( - "configuration names must be strings or integers".to_owned(), - ) + vm.new_type_error("configuration names must be strings or integers") })?; { let name = s.as_str(); @@ -2704,7 +2823,7 @@ mod posix_sched { class::StaticType, }; if !obj.fast_isinstance(PySchedParam::static_type()) { - return Err(vm.new_type_error("must have a sched_param object".to_owned())); + return Err(vm.new_type_error("must have a sched_param 
object")); } let tuple = obj.downcast_ref::().unwrap(); let priority = tuple[0].clone(); diff --git a/crates/vm/src/stdlib/sys.rs b/crates/vm/src/stdlib/sys.rs index 68734f94631..33325c9dc60 100644 --- a/crates/vm/src/stdlib/sys.rs +++ b/crates/vm/src/stdlib/sys.rs @@ -45,7 +45,7 @@ mod sys { convert::ToPyObject, frame::{Frame, FrameRef}, function::{FuncArgs, KwArgs, OptionalArg, PosArgs}, - stdlib::{builtins, warnings::warn}, + stdlib::{_warnings::warn, builtins}, types::PyStructSequence, version, vm::{Settings, VirtualMachine}, @@ -1037,7 +1037,7 @@ mod sys { #[pyfunction] fn _current_frames(vm: &VirtualMachine) -> PyResult { use crate::AsObject; - use crate::stdlib::thread::get_all_current_frames; + use crate::stdlib::_thread::get_all_current_frames; let frames = get_all_current_frames(vm); let dict = vm.ctx.new_dict(); @@ -1430,7 +1430,7 @@ mod sys { // Check if type is immutable if type_obj.slots.flags.has_feature(PyTypeFlags::IMMUTABLETYPE) { - return Err(vm.new_type_error("argument is immutable".to_owned())); + return Err(vm.new_type_error("argument is immutable")); } let mut attributes = type_obj.attributes.write(); @@ -1633,7 +1633,7 @@ mod sys { #[cfg(feature = "threading")] impl ThreadInfoData { const INFO: Self = Self { - name: crate::stdlib::thread::_thread::PYTHREAD_NAME, + name: crate::stdlib::_thread::_thread::PYTHREAD_NAME, // As I know, there's only way to use lock as "Mutex" in Rust // with satisfying python document spec. lock: Some("mutex+cond"), diff --git a/crates/vm/src/stdlib/sys/monitoring.rs b/crates/vm/src/stdlib/sys/monitoring.rs index 6d1aeb9c8f3..739165073af 100644 --- a/crates/vm/src/stdlib/sys/monitoring.rs +++ b/crates/vm/src/stdlib/sys/monitoring.rs @@ -777,7 +777,7 @@ fn fire( // Non-local events (RAISE, EXCEPTION_HANDLED, PY_UNWIND, etc.) // cannot be disabled per code object. if event_id >= LOCAL_EVENTS_COUNT { - // Remove the callback, matching CPython behavior. + // Remove the callback. 
let mut state = vm.state.monitoring.lock(); state.callbacks.remove(&(tool, event_id)); return Err(vm.new_value_error(format!( diff --git a/crates/vm/src/stdlib/time.rs b/crates/vm/src/stdlib/time.rs index 3477648c4a7..d38152db84a 100644 --- a/crates/vm/src/stdlib/time.rs +++ b/crates/vm/src/stdlib/time.rs @@ -117,8 +117,13 @@ mod decl { { // this is basically std::thread::sleep, but that catches interrupts and we don't want to; let ts = nix::sys::time::TimeSpec::from(dur); - let res = unsafe { libc::nanosleep(ts.as_ref(), core::ptr::null_mut()) }; - let interrupted = res == -1 && nix::Error::last_raw() == libc::EINTR; + // Capture errno inside the closure: attach_thread (called by + // allow_threads on return) can clobber errno via syscalls. + let (res, err) = vm.allow_threads(|| { + let r = unsafe { libc::nanosleep(ts.as_ref(), core::ptr::null_mut()) }; + (r, nix::Error::last_raw()) + }); + let interrupted = res == -1 && err == libc::EINTR; if interrupted { vm.check_signals()?; @@ -127,7 +132,7 @@ mod decl { #[cfg(not(unix))] { - std::thread::sleep(dur); + vm.allow_threads(|| std::thread::sleep(dur)); } Ok(()) @@ -630,7 +635,7 @@ mod decl { { let year = tm.tm_year + 1900; if !(1..=9999).contains(&year) { - return Err(vm.new_value_error("strftime() requires year in [1; 9999]".to_owned())); + return Err(vm.new_value_error("strftime() requires year in [1; 9999]")); } } diff --git a/crates/vm/src/stdlib/typevar.rs b/crates/vm/src/stdlib/typevar.rs index d0bd3f5666d..b28fad21bd7 100644 --- a/crates/vm/src/stdlib/typevar.rs +++ b/crates/vm/src/stdlib/typevar.rs @@ -10,7 +10,7 @@ pub(crate) mod typevar { common::lock::PyMutex, function::{FuncArgs, PyComparisonValue}, protocol::PyNumberMethods, - stdlib::typing::{call_typing_func_object, decl::const_evaluator_alloc}, + stdlib::_typing::{call_typing_func_object, decl::const_evaluator_alloc}, types::{AsNumber, Comparable, Constructor, Iterable, PyComparisonOp, Representable}, }; @@ -94,7 +94,10 @@ pub(crate) mod typevar { 
contravariant: bool, infer_variance: bool, } - #[pyclass(flags(HAS_DICT), with(AsNumber, Constructor, Representable))] + #[pyclass( + flags(HAS_DICT, HAS_WEAKREF), + with(AsNumber, Constructor, Representable) + )] impl TypeVar { #[pymethod] fn __mro_entries__(&self, _bases: PyObjectRef, vm: &VirtualMachine) -> PyResult { @@ -461,7 +464,10 @@ pub(crate) mod typevar { infer_variance: bool, } - #[pyclass(flags(HAS_DICT), with(AsNumber, Constructor, Representable))] + #[pyclass( + flags(HAS_DICT, HAS_WEAKREF), + with(AsNumber, Constructor, Representable) + )] impl ParamSpec { #[pymethod] fn __mro_entries__(&self, _bases: PyObjectRef, vm: &VirtualMachine) -> PyResult { @@ -713,7 +719,10 @@ pub(crate) mod typevar { default_value: PyMutex, evaluate_default: PyMutex, } - #[pyclass(flags(HAS_DICT), with(Constructor, Representable, Iterable))] + #[pyclass( + flags(HAS_DICT, HAS_WEAKREF), + with(Constructor, Representable, Iterable) + )] impl TypeVarTuple { #[pygetset] fn __name__(&self) -> PyObjectRef { @@ -883,7 +892,7 @@ pub(crate) mod typevar { pub struct ParamSpecArgs { __origin__: PyObjectRef, } - #[pyclass(with(Constructor, Representable, Comparable))] + #[pyclass(with(Constructor, Representable, Comparable), flags(HAS_WEAKREF))] impl ParamSpecArgs { #[pymethod] fn __mro_entries__(&self, _bases: PyObjectRef, vm: &VirtualMachine) -> PyResult { @@ -946,7 +955,7 @@ pub(crate) mod typevar { pub struct ParamSpecKwargs { __origin__: PyObjectRef, } - #[pyclass(with(Constructor, Representable, Comparable))] + #[pyclass(with(Constructor, Representable, Comparable), flags(HAS_WEAKREF))] impl ParamSpecKwargs { #[pymethod] fn __mro_entries__(&self, _bases: PyObjectRef, vm: &VirtualMachine) -> PyResult { diff --git a/crates/vm/src/stdlib/winreg.rs b/crates/vm/src/stdlib/winreg.rs index 026d8d38c63..264d14327da 100644 --- a/crates/vm/src/stdlib/winreg.rs +++ b/crates/vm/src/stdlib/winreg.rs @@ -958,15 +958,15 @@ mod winreg { } let val = value .downcast_ref::() - .ok_or_else(|| 
vm.new_type_error("value must be an integer".to_string()))?; + .ok_or_else(|| vm.new_type_error("value must be an integer"))?; let bigint = val.as_bigint(); // Check for negative value - raise OverflowError if bigint.sign() == Sign::Minus { - return Err(vm.new_overflow_error("int too big to convert".to_string())); + return Err(vm.new_overflow_error("int too big to convert")); } let val = bigint .to_u32() - .ok_or_else(|| vm.new_overflow_error("int too big to convert".to_string()))?; + .ok_or_else(|| vm.new_overflow_error("int too big to convert"))?; Ok(Some(val.to_le_bytes().to_vec())) } REG_QWORD => { @@ -975,15 +975,15 @@ mod winreg { } let val = value .downcast_ref::() - .ok_or_else(|| vm.new_type_error("value must be an integer".to_string()))?; + .ok_or_else(|| vm.new_type_error("value must be an integer"))?; let bigint = val.as_bigint(); // Check for negative value - raise OverflowError if bigint.sign() == Sign::Minus { - return Err(vm.new_overflow_error("int too big to convert".to_string())); + return Err(vm.new_overflow_error("int too big to convert")); } let val = bigint .to_u64() - .ok_or_else(|| vm.new_overflow_error("int too big to convert".to_string()))?; + .ok_or_else(|| vm.new_overflow_error("int too big to convert"))?; Ok(Some(val.to_le_bytes().to_vec())) } REG_SZ | REG_EXPAND_SZ => { @@ -993,7 +993,7 @@ mod winreg { } let s = value .downcast::() - .map_err(|_| vm.new_type_error("value must be a string".to_string()))?; + .map_err(|_| vm.new_type_error("value must be a string"))?; let wide = s.as_wtf8().to_wide_with_nul(); // Convert Vec to Vec let bytes: Vec = wide.iter().flat_map(|&c| c.to_le_bytes()).collect(); @@ -1004,15 +1004,15 @@ mod winreg { // Empty list = double null terminator return Ok(Some(vec![0u8, 0u8, 0u8, 0u8])); } - let list = value.downcast::().map_err(|_| { - vm.new_type_error("value must be a list of strings".to_string()) - })?; + let list = value + .downcast::() + .map_err(|_| vm.new_type_error("value must be a list of 
strings"))?; let mut bytes: Vec = Vec::new(); for item in list.borrow_vec().iter() { - let s = item.downcast_ref::().ok_or_else(|| { - vm.new_type_error("list items must be strings".to_string()) - })?; + let s = item + .downcast_ref::() + .ok_or_else(|| vm.new_type_error("list items must be strings"))?; let wide = s.as_wtf8().to_wide_with_nul(); bytes.extend(wide.iter().flat_map(|&c| c.to_le_bytes())); } diff --git a/crates/vm/src/stdlib/winsound.rs b/crates/vm/src/stdlib/winsound.rs index 729305f879d..0ca2e9a2258 100644 --- a/crates/vm/src/stdlib/winsound.rs +++ b/crates/vm/src/stdlib/winsound.rs @@ -87,24 +87,22 @@ mod winsound { if vm.is_none(&sound) { let ok = unsafe { super::win32::PlaySoundW(core::ptr::null(), 0, flags) }; if ok == 0 { - return Err(vm.new_runtime_error("Failed to play sound".to_owned())); + return Err(vm.new_runtime_error("Failed to play sound")); } return Ok(()); } if flags & SND_MEMORY != 0 { if flags & SND_ASYNC != 0 { - return Err( - vm.new_runtime_error("Cannot play asynchronously from memory".to_owned()) - ); + return Err(vm.new_runtime_error("Cannot play asynchronously from memory")); } let buffer = PyBuffer::try_from_borrowed_object(vm, &sound)?; - let buf = buffer.as_contiguous().ok_or_else(|| { - vm.new_type_error("a bytes-like object is required, not 'str'".to_owned()) - })?; + let buf = buffer + .as_contiguous() + .ok_or_else(|| vm.new_type_error("a bytes-like object is required, not 'str'"))?; let ok = unsafe { super::win32::PlaySoundW(buf.as_ptr() as *const u16, 0, flags) }; if ok == 0 { - return Err(vm.new_runtime_error("Failed to play sound".to_owned())); + return Err(vm.new_runtime_error("Failed to play sound")); } return Ok(()); } @@ -138,9 +136,7 @@ mod winsound { let result = fspath.call((), vm)?; if result.downcastable::() { - return Err( - vm.new_type_error("'sound' must resolve to str, not bytes".to_owned()) - ); + return Err(vm.new_type_error("'sound' must resolve to str, not bytes")); } let s: &PyStr = 
result.downcast_ref().ok_or_else(|| { @@ -157,13 +153,13 @@ mod winsound { // Check for embedded null characters if path.as_bytes().contains(&0) { - return Err(vm.new_value_error("embedded null character".to_owned())); + return Err(vm.new_value_error("embedded null character")); } let wide = path.to_wide_with_nul(); let ok = unsafe { super::win32::PlaySoundW(wide.as_ptr(), 0, flags) }; if ok == 0 { - return Err(vm.new_runtime_error("Failed to play sound".to_owned())); + return Err(vm.new_runtime_error("Failed to play sound")); } Ok(()) } @@ -179,12 +175,12 @@ mod winsound { #[pyfunction] fn Beep(args: BeepArgs, vm: &VirtualMachine) -> PyResult<()> { if !(37..=32767).contains(&args.frequency) { - return Err(vm.new_value_error("frequency must be in 37 thru 32767".to_owned())); + return Err(vm.new_value_error("frequency must be in 37 thru 32767")); } let ok = unsafe { super::win32::Beep(args.frequency as u32, args.duration as u32) }; if ok == 0 { - return Err(vm.new_runtime_error("Failed to beep".to_owned())); + return Err(vm.new_runtime_error("Failed to beep")); } Ok(()) } diff --git a/crates/vm/src/types/slot.rs b/crates/vm/src/types/slot.rs index 58040f7928c..222d827c7f5 100644 --- a/crates/vm/src/types/slot.rs +++ b/crates/vm/src/types/slot.rs @@ -174,6 +174,7 @@ pub struct PyTypeSlots { // tp_dictoffset pub init: AtomicCell>, // tp_alloc + pub alloc: AtomicCell>, pub new: AtomicCell>, // tp_free // tp_is_gc @@ -215,6 +216,7 @@ bitflags! { #[derive(Copy, Clone, Debug, PartialEq)] #[non_exhaustive] pub struct PyTypeFlags: u64 { + const MANAGED_WEAKREF = 1 << 3; const MANAGED_DICT = 1 << 4; const SEQUENCE = 1 << 5; const MAPPING = 1 << 6; @@ -228,6 +230,7 @@ bitflags! 
{ // This is not a stable API const _MATCH_SELF = 1 << 22; const HAS_DICT = 1 << 40; + const HAS_WEAKREF = 1 << 41; #[cfg(debug_assertions)] const _CREATED_WITH_FLAGS = 1 << 63; @@ -297,6 +300,7 @@ pub(crate) type DescrGetFunc = fn(PyObjectRef, Option, Option, &VirtualMachine) -> PyResult; pub(crate) type DescrSetFunc = fn(&PyObject, PyObjectRef, PySetterValue, &VirtualMachine) -> PyResult<()>; +pub(crate) type AllocFunc = fn(PyTypeRef, usize, &VirtualMachine) -> PyResult; pub(crate) type NewFunc = fn(PyTypeRef, FuncArgs, &VirtualMachine) -> PyResult; pub(crate) type InitFunc = fn(PyObjectRef, FuncArgs, &VirtualMachine) -> PyResult<()>; pub(crate) type DelFunc = fn(&PyObject, &VirtualMachine) -> PyResult<()>; @@ -610,7 +614,7 @@ fn init_wrapper(obj: PyObjectRef, args: FuncArgs, vm: &VirtualMachine) -> PyResu let res = vm.call_special_method(&obj, identifier!(vm, __init__), args)?; if !vm.is_none(&res) { return Err(vm.new_type_error(format!( - "__init__ should return None, not '{:.200}'", + "__init__() should return None, not '{:.200}'", res.class().name() ))); } diff --git a/crates/vm/src/types/structseq.rs b/crates/vm/src/types/structseq.rs index 0ac73c0fc19..0744d7a4a00 100644 --- a/crates/vm/src/types/structseq.rs +++ b/crates/vm/src/types/structseq.rs @@ -251,7 +251,7 @@ pub trait PyStructSequence: StaticType + PyClassImpl + Sized + 'static { #[pymethod] fn __replace__(zelf: PyRef, args: FuncArgs, vm: &VirtualMachine) -> PyResult { if !args.args.is_empty() { - return Err(vm.new_type_error("__replace__() takes no positional arguments".to_owned())); + return Err(vm.new_type_error("__replace__() takes no positional arguments")); } if Self::Data::UNNAMED_FIELDS_LEN > 0 { diff --git a/crates/vm/src/types/zoo.rs b/crates/vm/src/types/zoo.rs index a9999211680..0394f672cf8 100644 --- a/crates/vm/src/types/zoo.rs +++ b/crates/vm/src/types/zoo.rs @@ -198,7 +198,7 @@ impl TypeZoo { weakproxy_type: weakproxy::PyWeakProxy::init_builtin_type(), method_descriptor_type: 
descriptor::PyMethodDescriptor::init_builtin_type(), none_type: singletons::PyNone::init_builtin_type(), - typing_no_default_type: crate::stdlib::typing::NoDefault::init_builtin_type(), + typing_no_default_type: crate::stdlib::_typing::NoDefault::init_builtin_type(), not_implemented_type: singletons::PyNotImplemented::init_builtin_type(), generic_alias_type: genericalias::PyGenericAlias::init_builtin_type(), generic_alias_iterator_type: genericalias::PyGenericAliasIterator::init_builtin_type(), @@ -265,6 +265,6 @@ impl TypeZoo { interpolation::init(context); template::init(context); descriptor::init(context); - crate::stdlib::typing::init(context); + crate::stdlib::_typing::init(context); } } diff --git a/crates/vm/src/vm/context.rs b/crates/vm/src/vm/context.rs index d864548ff08..1edb8656e0f 100644 --- a/crates/vm/src/vm/context.rs +++ b/crates/vm/src/vm/context.rs @@ -14,6 +14,7 @@ use crate::{ object, pystr, type_::PyAttributes, }, + bytecode::{self, CodeFlags, CodeUnit, Instruction}, class::StaticType, common::rc::PyRc, exceptions, @@ -29,6 +30,7 @@ use malachite_bigint::BigInt; use num_complex::Complex64; use num_traits::ToPrimitive; use rustpython_common::lock::PyRwLock; +use rustpython_compiler_core::{OneIndexed, SourceLocation}; #[derive(Debug)] pub struct Context { @@ -42,11 +44,14 @@ pub struct Context { pub ellipsis: PyRef, pub not_implemented: PyRef, - pub typing_no_default: PyRef, + pub typing_no_default: PyRef, pub types: TypeZoo, pub exceptions: exceptions::ExceptionZoo, pub int_cache_pool: Vec, + pub(crate) latin1_char_cache: Vec>, + pub(crate) ascii_char_cache: Vec>, + pub(crate) init_cleanup_code: PyRef, // there should only be exact objects of str in here, no non-str objects and no subclasses pub(crate) string_pool: StringPool, pub(crate) slot_new_wrapper: PyMethodDef, @@ -311,8 +316,8 @@ impl Context { let not_implemented = create_object(PyNotImplemented, PyNotImplemented::static_type()); let typing_no_default = create_object( - 
crate::stdlib::typing::NoDefault, - crate::stdlib::typing::NoDefault::static_type(), + crate::stdlib::_typing::NoDefault, + crate::stdlib::_typing::NoDefault::static_type(), ); let int_cache_pool = Self::INT_CACHE_POOL_RANGE @@ -324,6 +329,10 @@ impl Context { ) }) .collect(); + let latin1_char_cache: Vec> = (0u8..=255) + .map(|b| create_object(PyStr::from(char::from(b)), types.str_type)) + .collect(); + let ascii_char_cache = latin1_char_cache[..128].to_vec(); let true_value = create_object(PyBool(PyInt::from(1)), types.bool_type); let false_value = create_object(PyBool(PyInt::from(0)), types.bool_type); @@ -347,6 +356,7 @@ impl Context { PyMethodFlags::METHOD, None, ); + let init_cleanup_code = Self::new_init_cleanup_code(&types, &names); let empty_str = unsafe { string_pool.intern("", types.str_type.to_owned()) }; let empty_bytes = create_object(PyBytes::from(Vec::new()), types.bytes_type); @@ -371,6 +381,9 @@ impl Context { types, exceptions, int_cache_pool, + latin1_char_cache, + ascii_char_cache, + init_cleanup_code, string_pool, slot_new_wrapper, names, @@ -380,6 +393,51 @@ impl Context { } } + fn new_init_cleanup_code(types: &TypeZoo, names: &ConstName) -> PyRef { + let loc = SourceLocation { + line: OneIndexed::MIN, + character_offset: OneIndexed::from_zero_indexed(0), + }; + let instructions = [ + CodeUnit { + op: Instruction::ExitInitCheck, + arg: 0.into(), + }, + CodeUnit { + op: Instruction::ReturnValue, + arg: 0.into(), + }, + CodeUnit { + op: Instruction::Resume { + context: bytecode::Arg::marker(), + }, + arg: 0.into(), + }, + ]; + let code = bytecode::CodeObject { + instructions: instructions.into(), + locations: vec![(loc, loc); instructions.len()].into_boxed_slice(), + flags: CodeFlags::OPTIMIZED, + posonlyarg_count: 0, + arg_count: 0, + kwonlyarg_count: 0, + source_path: names.__init__, + first_line_number: None, + max_stackdepth: 2, + obj_name: names.__init__, + qualname: names.__init__, + cell2arg: None, + constants: 
core::iter::empty().collect(), + names: Vec::new().into_boxed_slice(), + varnames: Vec::new().into_boxed_slice(), + cellvars: Vec::new().into_boxed_slice(), + freevars: Vec::new().into_boxed_slice(), + linetable: Vec::new().into_boxed_slice(), + exceptiontable: Vec::new().into_boxed_slice(), + }; + PyRef::new_ref(PyCode::new(code), types.code_type.to_owned(), None) + } + pub fn intern_str(&self, s: S) -> &'static PyStrInterned { unsafe { self.string_pool.intern(s, self.types.str_type.to_owned()) } } @@ -450,9 +508,28 @@ impl Context { PyComplex::from(value).into_ref(self) } + #[inline] + pub fn latin1_char(&self, ch: u8) -> PyRef { + self.latin1_char_cache[ch as usize].clone() + } + + #[inline] + fn latin1_singleton_index(s: &PyStr) -> Option { + let mut cps = s.as_wtf8().code_points(); + let cp = cps.next()?; + if cps.next().is_some() { + return None; + } + u8::try_from(cp.to_u32()).ok() + } + #[inline] pub fn new_str(&self, s: impl Into) -> PyRef { - s.into().into_ref(self) + let s = s.into(); + if let Some(ch) = Self::latin1_singleton_index(&s) { + return self.latin1_char(ch); + } + s.into_ref(self) } #[inline] diff --git a/crates/vm/src/vm/interpreter.rs b/crates/vm/src/vm/interpreter.rs index 8e275d1ce9e..5bf7436e958 100644 --- a/crates/vm/src/vm/interpreter.rs +++ b/crates/vm/src/vm/interpreter.rs @@ -1,3 +1,5 @@ +#[cfg(all(unix, feature = "threading"))] +use super::StopTheWorldState; use super::{Context, PyConfig, PyGlobalState, VirtualMachine, setting::Settings, thread}; use crate::{ PyResult, builtins, common::rc::PyRc, frozen::FrozenModule, getpath, py_freeze, stdlib::atexit, @@ -124,6 +126,8 @@ where monitoring: PyMutex::default(), monitoring_events: AtomicCell::new(0), instrumentation_version: AtomicU64::new(0), + #[cfg(all(unix, feature = "threading"))] + stop_the_world: StopTheWorldState::new(), }); // Create VM with the global state @@ -470,8 +474,10 @@ fn core_frozen_inits() -> impl Iterator { crate_name = "rustpython_compiler_core" ); - // Collect 
and add frozen module aliases for test modules + // Collect frozen module entries let mut entries: Vec<_> = iter.collect(); + + // Add test module aliases if let Some(hello_code) = entries .iter() .find(|(n, _)| *n == "__hello__") diff --git a/crates/vm/src/vm/mod.rs b/crates/vm/src/vm/mod.rs index 6461502a582..05210eb09d7 100644 --- a/crates/vm/src/vm/mod.rs +++ b/crates/vm/src/vm/mod.rs @@ -40,6 +40,8 @@ use crate::{ warn::WarningsState, }; use alloc::{borrow::Cow, collections::BTreeMap}; +#[cfg(all(unix, feature = "threading"))] +use core::sync::atomic::AtomicI64; use core::{ cell::{Cell, OnceCell, RefCell}, ptr::NonNull, @@ -92,6 +94,7 @@ pub struct VirtualMachine { pub initialized: bool, recursion_depth: Cell, /// C stack soft limit for detecting stack overflow (like c_stack_soft_limit) + #[cfg_attr(miri, allow(dead_code))] c_stack_soft_limit: Cell, /// Async generator firstiter hook (per-thread, set via sys.set_asyncgen_hooks) pub async_gen_firstiter: RefCell>, @@ -101,6 +104,7 @@ pub struct VirtualMachine { pub asyncio_running_loop: RefCell>, /// Current running asyncio task for this thread pub asyncio_running_task: RefCell>, + pub(crate) callable_cache: CallableCache, } /// Non-owning frame pointer for the frames stack. @@ -125,6 +129,455 @@ struct ExceptionStack { stack: Vec>, } +/// Stop-the-world state for fork safety. Before `fork()`, the requester +/// stops all other Python threads so they are not holding internal locks. +#[cfg(all(unix, feature = "threading"))] +pub struct StopTheWorldState { + /// Fast-path flag checked in the bytecode loop (like `_PY_EVAL_PLEASE_STOP_BIT`) + pub(crate) requested: AtomicBool, + /// Whether the world is currently stopped (`stw->world_stopped`). 
+ world_stopped: AtomicBool, + /// Ident of the thread that requested the stop (like `stw->requester`) + requester: AtomicU64, + /// Signaled by suspending threads when their state transitions to SUSPENDED + notify_mutex: std::sync::Mutex<()>, + notify_cv: std::sync::Condvar, + /// Number of non-requester threads still expected to park for current stop request. + thread_countdown: AtomicI64, + /// Number of stop-the-world attempts. + stats_stop_calls: AtomicU64, + /// Most recent stop-the-world wait duration in ns. + stats_last_wait_ns: AtomicU64, + /// Total accumulated stop-the-world wait duration in ns. + stats_total_wait_ns: AtomicU64, + /// Max observed stop-the-world wait duration in ns. + stats_max_wait_ns: AtomicU64, + /// Number of poll-loop iterations spent waiting. + stats_poll_loops: AtomicU64, + /// Number of ATTACHED threads observed while polling. + stats_attached_seen: AtomicU64, + /// Number of DETACHED->SUSPENDED parks requested by requester. + stats_forced_parks: AtomicU64, + /// Number of suspend notifications from worker threads. + stats_suspend_notifications: AtomicU64, + /// Number of yield loops while attach waited on SUSPENDED->DETACHED. + stats_attach_wait_yields: AtomicU64, + /// Number of yield loops while suspend waited on SUSPENDED->DETACHED. 
+ stats_suspend_wait_yields: AtomicU64, +} + +#[cfg(all(unix, feature = "threading"))] +#[derive(Debug, Clone, Copy)] +pub struct StopTheWorldStats { + pub stop_calls: u64, + pub last_wait_ns: u64, + pub total_wait_ns: u64, + pub max_wait_ns: u64, + pub poll_loops: u64, + pub attached_seen: u64, + pub forced_parks: u64, + pub suspend_notifications: u64, + pub attach_wait_yields: u64, + pub suspend_wait_yields: u64, + pub world_stopped: bool, +} + +#[cfg(all(unix, feature = "threading"))] +impl Default for StopTheWorldState { + fn default() -> Self { + Self::new() + } +} + +#[cfg(all(unix, feature = "threading"))] +impl StopTheWorldState { + pub const fn new() -> Self { + Self { + requested: AtomicBool::new(false), + world_stopped: AtomicBool::new(false), + requester: AtomicU64::new(0), + notify_mutex: std::sync::Mutex::new(()), + notify_cv: std::sync::Condvar::new(), + thread_countdown: AtomicI64::new(0), + stats_stop_calls: AtomicU64::new(0), + stats_last_wait_ns: AtomicU64::new(0), + stats_total_wait_ns: AtomicU64::new(0), + stats_max_wait_ns: AtomicU64::new(0), + stats_poll_loops: AtomicU64::new(0), + stats_attached_seen: AtomicU64::new(0), + stats_forced_parks: AtomicU64::new(0), + stats_suspend_notifications: AtomicU64::new(0), + stats_attach_wait_yields: AtomicU64::new(0), + stats_suspend_wait_yields: AtomicU64::new(0), + } + } + + /// Wake the stop-the-world requester (called by each thread that suspends). + pub(crate) fn notify_suspended(&self) { + self.stats_suspend_notifications + .fetch_add(1, Ordering::Relaxed); + // Synchronize with requester wait loop to avoid lost wakeups. 
+ let _guard = self.notify_mutex.lock().unwrap(); + self.decrement_thread_countdown(1); + self.notify_cv.notify_one(); + } + + #[inline] + fn init_thread_countdown(&self, vm: &VirtualMachine) -> i64 { + let requester = self.requester.load(Ordering::Relaxed); + let registry = vm.state.thread_frames.lock(); + // Keep requested/count initialization serialized with thread-slot + // registration (which also takes this lock), matching the + // HEAD_LOCK-guarded stop-the-world bookkeeping. + self.requested.store(true, Ordering::Release); + let count = registry + .keys() + .filter(|&&thread_id| thread_id != requester) + .count(); + let count = (count.min(i64::MAX as usize)) as i64; + self.thread_countdown.store(count, Ordering::Release); + count + } + + #[inline] + fn decrement_thread_countdown(&self, n: u64) { + if n == 0 { + return; + } + let n = (n.min(i64::MAX as u64)) as i64; + let prev = self.thread_countdown.fetch_sub(n, Ordering::AcqRel); + if prev <= n { + // Clamp at 0 for safety in case of duplicate notifications. + self.thread_countdown.store(0, Ordering::Release); + } + } + + /// Try to CAS detached threads directly to SUSPENDED and check whether + /// stop countdown reached zero after parking detached threads. 
+ fn park_detached_threads(&self, vm: &VirtualMachine) -> bool { + use thread::{THREAD_ATTACHED, THREAD_DETACHED, THREAD_SUSPENDED}; + let requester = self.requester.load(Ordering::Relaxed); + let registry = vm.state.thread_frames.lock(); + let mut attached_seen = 0u64; + let mut forced_parks = 0u64; + for (&id, slot) in registry.iter() { + if id == requester { + continue; + } + let state = slot.state.load(Ordering::Relaxed); + if state == THREAD_DETACHED { + // CAS DETACHED → SUSPENDED (park without thread cooperation) + match slot.state.compare_exchange( + THREAD_DETACHED, + THREAD_SUSPENDED, + Ordering::AcqRel, + Ordering::Relaxed, + ) { + Ok(_) => { + slot.stop_requested.store(false, Ordering::Release); + forced_parks = forced_parks.saturating_add(1); + } + Err(THREAD_ATTACHED) => { + // Set per-thread stop bit (_PY_EVAL_PLEASE_STOP_BIT). + slot.stop_requested.store(true, Ordering::Release); + // Raced with a thread re-attaching; it will self-suspend. + attached_seen = attached_seen.saturating_add(1); + } + Err(THREAD_DETACHED) => { + // Extremely unlikely race; next poll will handle it. + } + Err(THREAD_SUSPENDED) => { + slot.stop_requested.store(false, Ordering::Release); + // Another path parked it first. + } + Err(other) => { + debug_assert!( + false, + "unexpected thread state in park_detached_threads: {other}" + ); + } + } + } else if state == THREAD_ATTACHED { + // Set per-thread stop bit (_PY_EVAL_PLEASE_STOP_BIT). 
+ slot.stop_requested.store(true, Ordering::Release); + // Thread is in bytecode — it will see `requested` and self-suspend + attached_seen = attached_seen.saturating_add(1); + } + // THREAD_SUSPENDED → already parked + } + if attached_seen != 0 { + self.stats_attached_seen + .fetch_add(attached_seen, Ordering::Relaxed); + } + if forced_parks != 0 { + self.decrement_thread_countdown(forced_parks); + self.stats_forced_parks + .fetch_add(forced_parks, Ordering::Relaxed); + } + forced_parks != 0 && self.thread_countdown.load(Ordering::Acquire) == 0 + } + + /// Stop all non-requester threads (`stop_the_world`). + /// + /// 1. Sets `requested`, marking the requester thread. + /// 2. CAS detached threads to SUSPENDED. + /// 3. Waits (polling with 1 ms condvar timeout) for attached threads + /// to self-suspend in `check_signals`. + pub fn stop_the_world(&self, vm: &VirtualMachine) { + let start = std::time::Instant::now(); + let requester_ident = crate::stdlib::_thread::get_ident(); + self.requester.store(requester_ident, Ordering::Relaxed); + self.stats_stop_calls.fetch_add(1, Ordering::Relaxed); + let initial_countdown = self.init_thread_countdown(vm); + stw_trace(format_args!("stop begin requester={requester_ident}")); + if initial_countdown == 0 { + self.world_stopped.store(true, Ordering::Release); + #[cfg(debug_assertions)] + self.debug_assert_all_non_requester_suspended(vm); + stw_trace(format_args!( + "stop end requester={requester_ident} wait_ns=0 polls=0" + )); + return; + } + + let mut polls = 0u64; + loop { + if self.park_detached_threads(vm) { + break; + } + polls = polls.saturating_add(1); + // Wait up to 1 ms for a thread to notify us it suspended. + // Re-check under the wait mutex first to avoid a lost-wake race: + // a thread may have suspended and notified right before we enter wait. 
+ let guard = self.notify_mutex.lock().unwrap(); + if self.thread_countdown.load(Ordering::Acquire) == 0 || self.park_detached_threads(vm) + { + drop(guard); + break; + } + let _ = self + .notify_cv + .wait_timeout(guard, core::time::Duration::from_millis(1)); + } + if polls != 0 { + self.stats_poll_loops.fetch_add(polls, Ordering::Relaxed); + } + let wait_ns = start.elapsed().as_nanos().min(u128::from(u64::MAX)) as u64; + self.stats_last_wait_ns.store(wait_ns, Ordering::Relaxed); + self.stats_total_wait_ns + .fetch_add(wait_ns, Ordering::Relaxed); + let mut prev_max = self.stats_max_wait_ns.load(Ordering::Relaxed); + while wait_ns > prev_max { + match self.stats_max_wait_ns.compare_exchange_weak( + prev_max, + wait_ns, + Ordering::Relaxed, + Ordering::Relaxed, + ) { + Ok(_) => break, + Err(observed) => prev_max = observed, + } + } + self.world_stopped.store(true, Ordering::Release); + #[cfg(debug_assertions)] + self.debug_assert_all_non_requester_suspended(vm); + stw_trace(format_args!( + "stop end requester={requester_ident} wait_ns={wait_ns} polls={polls}" + )); + } + + /// Resume all suspended threads (`start_the_world`). + pub fn start_the_world(&self, vm: &VirtualMachine) { + use thread::{THREAD_DETACHED, THREAD_SUSPENDED}; + let requester = self.requester.load(Ordering::Relaxed); + stw_trace(format_args!("start begin requester={requester}")); + let registry = vm.state.thread_frames.lock(); + // Clear the request flag BEFORE waking threads. Otherwise a thread + // returning from allow_threads → attach_thread could observe + // `requested == true`, re-suspend itself, and stay parked forever. + // Keep this write under the registry lock to serialize with new + // thread-slot initialization. 
+ self.requested.store(false, Ordering::Release); + self.world_stopped.store(false, Ordering::Release); + for (&id, slot) in registry.iter() { + if id == requester { + continue; + } + slot.stop_requested.store(false, Ordering::Release); + let state = slot.state.load(Ordering::Relaxed); + debug_assert!( + state == THREAD_SUSPENDED, + "non-requester thread not suspended at start-the-world: id={id} state={state}" + ); + if state == THREAD_SUSPENDED { + slot.state.store(THREAD_DETACHED, Ordering::Release); + slot.thread.unpark(); + } + } + drop(registry); + self.thread_countdown.store(0, Ordering::Release); + self.requester.store(0, Ordering::Relaxed); + #[cfg(debug_assertions)] + self.debug_assert_all_non_requester_detached(vm); + stw_trace(format_args!("start end requester={requester}")); + } + + /// Reset after fork in the child (only one thread alive). + pub fn reset_after_fork(&self) { + self.requested.store(false, Ordering::Relaxed); + self.world_stopped.store(false, Ordering::Relaxed); + self.requester.store(0, Ordering::Relaxed); + self.thread_countdown.store(0, Ordering::Relaxed); + stw_trace(format_args!("reset-after-fork")); + } + + #[inline] + pub(crate) fn requester_ident(&self) -> u64 { + self.requester.load(Ordering::Relaxed) + } + + #[inline] + pub(crate) fn notify_thread_gone(&self) { + let _guard = self.notify_mutex.lock().unwrap(); + self.decrement_thread_countdown(1); + self.notify_cv.notify_one(); + } + + pub fn stats_snapshot(&self) -> StopTheWorldStats { + StopTheWorldStats { + stop_calls: self.stats_stop_calls.load(Ordering::Relaxed), + last_wait_ns: self.stats_last_wait_ns.load(Ordering::Relaxed), + total_wait_ns: self.stats_total_wait_ns.load(Ordering::Relaxed), + max_wait_ns: self.stats_max_wait_ns.load(Ordering::Relaxed), + poll_loops: self.stats_poll_loops.load(Ordering::Relaxed), + attached_seen: self.stats_attached_seen.load(Ordering::Relaxed), + forced_parks: self.stats_forced_parks.load(Ordering::Relaxed), + suspend_notifications: 
self.stats_suspend_notifications.load(Ordering::Relaxed), + attach_wait_yields: self.stats_attach_wait_yields.load(Ordering::Relaxed), + suspend_wait_yields: self.stats_suspend_wait_yields.load(Ordering::Relaxed), + world_stopped: self.world_stopped.load(Ordering::Relaxed), + } + } + + pub fn reset_stats(&self) { + self.stats_stop_calls.store(0, Ordering::Relaxed); + self.stats_last_wait_ns.store(0, Ordering::Relaxed); + self.stats_total_wait_ns.store(0, Ordering::Relaxed); + self.stats_max_wait_ns.store(0, Ordering::Relaxed); + self.stats_poll_loops.store(0, Ordering::Relaxed); + self.stats_attached_seen.store(0, Ordering::Relaxed); + self.stats_forced_parks.store(0, Ordering::Relaxed); + self.stats_suspend_notifications.store(0, Ordering::Relaxed); + self.stats_attach_wait_yields.store(0, Ordering::Relaxed); + self.stats_suspend_wait_yields.store(0, Ordering::Relaxed); + } + + #[inline] + pub(crate) fn add_attach_wait_yields(&self, n: u64) { + if n != 0 { + self.stats_attach_wait_yields + .fetch_add(n, Ordering::Relaxed); + } + } + + #[inline] + pub(crate) fn add_suspend_wait_yields(&self, n: u64) { + if n != 0 { + self.stats_suspend_wait_yields + .fetch_add(n, Ordering::Relaxed); + } + } + + #[cfg(debug_assertions)] + fn debug_assert_all_non_requester_suspended(&self, vm: &VirtualMachine) { + use thread::THREAD_SUSPENDED; + let requester = self.requester.load(Ordering::Relaxed); + let registry = vm.state.thread_frames.lock(); + for (&id, slot) in registry.iter() { + if id == requester { + continue; + } + let state = slot.state.load(Ordering::Relaxed); + debug_assert!( + state == THREAD_SUSPENDED, + "non-requester thread not suspended during stop-the-world: id={id} state={state}" + ); + } + } + + #[cfg(debug_assertions)] + fn debug_assert_all_non_requester_detached(&self, vm: &VirtualMachine) { + use thread::THREAD_SUSPENDED; + let requester = self.requester.load(Ordering::Relaxed); + let registry = vm.state.thread_frames.lock(); + for (&id, slot) in 
registry.iter() { + if id == requester { + continue; + } + let state = slot.state.load(Ordering::Relaxed); + debug_assert!( + state != THREAD_SUSPENDED, + "non-requester thread still suspended after start-the-world: id={id} state={state}" + ); + } + } +} + +#[cfg(all(unix, feature = "threading"))] +pub(super) fn stw_trace_enabled() -> bool { + static ENABLED: std::sync::OnceLock = std::sync::OnceLock::new(); + *ENABLED.get_or_init(|| std::env::var_os("RUSTPYTHON_STW_TRACE").is_some()) +} + +#[cfg(all(unix, feature = "threading"))] +pub(super) fn stw_trace(msg: core::fmt::Arguments<'_>) { + if stw_trace_enabled() { + use core::fmt::Write as _; + + // Avoid stdio locking here: this path runs around fork where a child + // may inherit a borrowed stderr lock and panic on eprintln!/stderr. + struct FixedBuf { + buf: [u8; 512], + len: usize, + } + + impl core::fmt::Write for FixedBuf { + fn write_str(&mut self, s: &str) -> core::fmt::Result { + if self.len >= self.buf.len() { + return Ok(()); + } + let remain = self.buf.len() - self.len; + let src = s.as_bytes(); + let n = src.len().min(remain); + self.buf[self.len..self.len + n].copy_from_slice(&src[..n]); + self.len += n; + Ok(()) + } + } + + let mut out = FixedBuf { + buf: [0u8; 512], + len: 0, + }; + let _ = writeln!( + &mut out, + "[rp-stw tid={}] {}", + crate::stdlib::_thread::get_ident(), + msg + ); + unsafe { + let _ = libc::write(libc::STDERR_FILENO, out.buf.as_ptr().cast(), out.len); + } + } +} + +#[derive(Clone, Debug, Default)] +pub(crate) struct CallableCache { + pub len: Option, + pub isinstance: Option, + pub list_append: Option, +} + pub struct PyGlobalState { pub config: PyConfig, pub module_defs: BTreeMap<&'static str, &'static builtins::PyModuleDef>, @@ -132,7 +585,7 @@ pub struct PyGlobalState { pub stacksize: AtomicCell, pub thread_count: AtomicCell, pub hash_secret: HashSecret, - pub atexit_funcs: PyMutex>, + pub atexit_funcs: PyMutex>>, pub codec_registry: CodecsRegistry, pub finalizing: 
AtomicBool, pub warnings: WarningsState, @@ -151,13 +604,13 @@ pub struct PyGlobalState { pub main_thread_ident: AtomicCell, /// Registry of all threads' slots for sys._current_frames() and sys._current_exceptions() #[cfg(feature = "threading")] - pub thread_frames: parking_lot::Mutex>, + pub thread_frames: parking_lot::Mutex>, /// Registry of all ThreadHandles for fork cleanup #[cfg(feature = "threading")] - pub thread_handles: parking_lot::Mutex>, + pub thread_handles: parking_lot::Mutex>, /// Registry for non-daemon threads that need to be joined at shutdown #[cfg(feature = "threading")] - pub shutdown_handles: parking_lot::Mutex>, + pub shutdown_handles: parking_lot::Mutex>, /// sys.monitoring state (tool names, events, callbacks) pub monitoring: PyMutex, /// Fast-path mask: OR of all tools' events. 0 means no monitoring overhead. @@ -165,6 +618,9 @@ pub struct PyGlobalState { /// Incremented on every monitoring state change. Code objects compare their /// local version against this to decide whether re-instrumentation is needed. pub instrumentation_version: AtomicU64, + /// Stop-the-world state for pre-fork thread suspension + #[cfg(all(unix, feature = "threading"))] + pub stop_the_world: StopTheWorldState, } pub fn process_hash_secret_seed() -> u32 { @@ -175,6 +631,19 @@ pub fn process_hash_secret_seed() -> u32 { } impl VirtualMachine { + fn init_callable_cache(&mut self) -> PyResult<()> { + self.callable_cache.len = Some(self.builtins.get_attr("len", self)?); + self.callable_cache.isinstance = Some(self.builtins.get_attr("isinstance", self)?); + let list_append = self + .ctx + .types + .list_type + .get_attr(self.ctx.intern_str("append")) + .ok_or_else(|| self.new_runtime_error("failed to cache list.append".to_owned()))?; + self.callable_cache.list_append = Some(list_append); + Ok(()) + } + /// Bump-allocate `size` bytes from the thread data stack. 
/// /// # Safety @@ -184,6 +653,12 @@ impl VirtualMachine { unsafe { (*self.datastack.get()).push(size) } } + /// Check whether the thread data stack currently has room for `size` bytes. + #[inline(always)] + pub(crate) fn datastack_has_space(&self, size: usize) -> bool { + unsafe { (*self.datastack.get()).has_space(size) } + } + /// Pop a previous data stack allocation. /// /// # Safety @@ -194,13 +669,23 @@ impl VirtualMachine { unsafe { (*self.datastack.get()).pop(base) } } + /// Temporarily detach the current thread (ATTACHED → DETACHED) while + /// running `f`, then re-attach afterwards. Allows `stop_the_world` to + /// park this thread during blocking syscalls. + /// + /// Equivalent to CPython's `Py_BEGIN_ALLOW_THREADS` / `Py_END_ALLOW_THREADS`. + #[inline] + pub fn allow_threads(&self, f: impl FnOnce() -> R) -> R { + thread::allow_threads(self, f) + } + /// Check whether the current thread is the main thread. /// Mirrors `_Py_ThreadCanHandleSignals`. #[allow(dead_code)] pub(crate) fn is_main_thread(&self) -> bool { #[cfg(feature = "threading")] { - crate::stdlib::thread::get_ident() == self.state.main_thread_ident.load() + crate::stdlib::_thread::get_ident() == self.state.main_thread_ident.load() } #[cfg(not(feature = "threading"))] { @@ -257,6 +742,7 @@ impl VirtualMachine { async_gen_finalizer: RefCell::new(None), asyncio_running_loop: RefCell::new(None), asyncio_running_task: RefCell::new(None), + callable_cache: CallableCache::default(), }; if vm.state.hash_secret.hash_str("") @@ -388,9 +874,11 @@ impl VirtualMachine { // Initialize main thread ident before any threading operations #[cfg(feature = "threading")] - stdlib::thread::init_main_thread_ident(self); + stdlib::_thread::init_main_thread_ident(self); stdlib::builtins::init_module(self, &self.builtins); + let callable_cache_init = self.init_callable_cache(); + self.expect_pyresult(callable_cache_init, "failed to initialize callable cache"); stdlib::sys::init_module(self, &self.sys_module, 
&self.builtins); self.expect_pyresult( stdlib::sys::set_bootstrap_stderr(self), @@ -416,10 +904,10 @@ impl VirtualMachine { let make_stdio = |name: &str, fd: i32, write: bool| -> PyResult { let buffered_stdio = self.state.config.settings.buffered_stdio; let unbuffered = write && !buffered_stdio; - let buf = crate::stdlib::io::open( + let buf = crate::stdlib::_io::open( self.ctx.new_int(fd).into(), Some(if write { "wb" } else { "rb" }), - crate::stdlib::io::OpenArgs { + crate::stdlib::_io::OpenArgs { buffering: if unbuffered { 0 } else { -1 }, closefd: false, ..Default::default() @@ -932,6 +1420,7 @@ impl VirtualMachine { /// Stack margin bytes (like _PyOS_STACK_MARGIN_BYTES). /// 2048 * sizeof(void*) = 16KB for 64-bit. + #[cfg_attr(miri, allow(dead_code))] const STACK_MARGIN_BYTES: usize = 2048 * core::mem::size_of::(); /// Get the stack boundaries using platform-specific APIs. @@ -1061,7 +1550,7 @@ impl VirtualMachine { frame: FrameRef, f: F, ) -> PyResult { - self.with_frame_exc(frame, None, f) + self.with_frame_impl(frame, None, true, f) } /// Like `with_frame` but allows specifying the initial exception state. 
@@ -1070,6 +1559,24 @@ impl VirtualMachine { frame: FrameRef, exc: Option, f: F, + ) -> PyResult { + self.with_frame_impl(frame, exc, true, f) + } + + pub(crate) fn with_frame_untraced PyResult>( + &self, + frame: FrameRef, + f: F, + ) -> PyResult { + self.with_frame_impl(frame, None, false, f) + } + + fn with_frame_impl PyResult>( + &self, + frame: FrameRef, + exc: Option, + traced: bool, + f: F, ) -> PyResult { self.with_recursion("", || { // SAFETY: `frame` (FrameRef) stays alive for the entire closure scope, @@ -1105,7 +1612,11 @@ impl VirtualMachine { crate::vm::thread::pop_thread_frame(); } - self.dispatch_traced_frame(&frame, |frame| f(frame.to_owned())) + if traced { + self.dispatch_traced_frame(&frame, |frame| f(frame.to_owned())) + } else { + f(frame.to_owned()) + } }) } @@ -1471,6 +1982,26 @@ impl VirtualMachine { self.get_method(obj, method_name) } + #[inline] + pub(crate) fn eval_breaker_tripped(&self) -> bool { + #[cfg(feature = "threading")] + if self.state.finalizing.load(Ordering::Relaxed) && !self.is_main_thread() { + return true; + } + + #[cfg(all(unix, feature = "threading"))] + if thread::stop_requested_for_current_thread() { + return true; + } + + #[cfg(not(target_arch = "wasm32"))] + if crate::signal::is_triggered() { + return true; + } + + false + } + #[inline] /// Checks for triggered signals and calls the appropriate handlers. A no-op on /// platforms where signals are not supported. 
@@ -1482,6 +2013,10 @@ impl VirtualMachine { return Err(self.new_exception(self.ctx.exceptions.system_exit.to_owned(), vec![])); } + // Suspend this thread if stop-the-world is in progress + #[cfg(all(unix, feature = "threading"))] + thread::suspend_if_needed(&self.state.stop_the_world); + #[cfg(not(target_arch = "wasm32"))] { crate::signal::check_signals(self) diff --git a/crates/vm/src/vm/python_run.rs b/crates/vm/src/vm/python_run.rs index 70d845b03f5..2f6f0bbee01 100644 --- a/crates/vm/src/vm/python_run.rs +++ b/crates/vm/src/vm/python_run.rs @@ -97,9 +97,9 @@ mod file_run { let loader = module_dict.get_item("__loader__", self)?; let get_code = loader.get_attr("get_code", self)?; let code_obj = get_code.call((identifier!(self, __main__).to_owned(),), self)?; - let code = code_obj.downcast::().map_err(|_| { - self.new_runtime_error("Bad code object in .pyc file".to_owned()) - })?; + let code = code_obj + .downcast::() + .map_err(|_| self.new_runtime_error("Bad code object in .pyc file"))?; self.run_code_obj(code, scope)?; } else { if path != "" { diff --git a/crates/vm/src/vm/thread.rs b/crates/vm/src/vm/thread.rs index 8dd8e0312ee..e7cc64f00b4 100644 --- a/crates/vm/src/vm/thread.rs +++ b/crates/vm/src/vm/thread.rs @@ -14,6 +14,17 @@ use core::{ use itertools::Itertools; use std::thread_local; +// Thread states for stop-the-world support. +// DETACHED: not executing Python bytecode (in native code, or idle) +// ATTACHED: actively executing Python bytecode +// SUSPENDED: parked by a stop-the-world request +#[cfg(all(unix, feature = "threading"))] +pub const THREAD_DETACHED: i32 = 0; +#[cfg(all(unix, feature = "threading"))] +pub const THREAD_ATTACHED: i32 = 1; +#[cfg(all(unix, feature = "threading"))] +pub const THREAD_SUSPENDED: i32 = 2; + /// Per-thread shared state for sys._current_frames() and sys._current_exceptions(). /// The exception field uses atomic operations for lock-free cross-thread reads. 
#[cfg(feature = "threading")] @@ -22,6 +33,15 @@ pub struct ThreadSlot { /// Readers must hold the Mutex and convert to FrameRef inside the lock. pub frames: parking_lot::Mutex>, pub exception: crate::PyAtomicRef>, + /// Thread state for stop-the-world: DETACHED / ATTACHED / SUSPENDED + #[cfg(unix)] + pub state: core::sync::atomic::AtomicI32, + /// Per-thread stop request bit (eval breaker equivalent). + #[cfg(unix)] + pub stop_requested: core::sync::atomic::AtomicBool, + /// Handle for waking this thread from park in stop-the-world paths. + #[cfg(unix)] + pub thread: std::thread::Thread, } #[cfg(feature = "threading")] @@ -57,13 +77,29 @@ pub fn with_current_vm(f: impl FnOnce(&VirtualMachine) -> R) -> R { pub fn enter_vm(vm: &VirtualMachine, f: impl FnOnce() -> R) -> R { VM_STACK.with(|vms| { + // Outermost enter_vm: transition DETACHED → ATTACHED + #[cfg(all(unix, feature = "threading"))] + let was_outermost = vms.borrow().is_empty(); + vms.borrow_mut().push(vm.into()); // Initialize thread slot for this thread if not already done #[cfg(feature = "threading")] init_thread_slot_if_needed(vm); - scopeguard::defer! { vms.borrow_mut().pop(); } + #[cfg(all(unix, feature = "threading"))] + if was_outermost { + attach_thread(vm); + } + + scopeguard::defer! 
{ + // Outermost exit: transition ATTACHED → DETACHED + #[cfg(all(unix, feature = "threading"))] + if vms.borrow().len() == 1 { + detach_thread(); + } + vms.borrow_mut().pop(); + } VM_CURRENT.set(vm, f) }) } @@ -74,20 +110,249 @@ pub fn enter_vm(vm: &VirtualMachine, f: impl FnOnce() -> R) -> R { fn init_thread_slot_if_needed(vm: &VirtualMachine) { CURRENT_THREAD_SLOT.with(|slot| { if slot.borrow().is_none() { - let thread_id = crate::stdlib::thread::get_ident(); + let thread_id = crate::stdlib::_thread::get_ident(); + let mut registry = vm.state.thread_frames.lock(); let new_slot = Arc::new(ThreadSlot { frames: parking_lot::Mutex::new(Vec::new()), exception: crate::PyAtomicRef::from(None::), + #[cfg(unix)] + state: core::sync::atomic::AtomicI32::new( + if vm.state.stop_the_world.requested.load(Ordering::Acquire) { + // Match init_threadstate(): new thread-state starts + // suspended while stop-the-world is active. + THREAD_SUSPENDED + } else { + THREAD_DETACHED + }, + ), + #[cfg(unix)] + stop_requested: core::sync::atomic::AtomicBool::new(false), + #[cfg(unix)] + thread: std::thread::current(), }); - vm.state - .thread_frames - .lock() - .insert(thread_id, new_slot.clone()); + registry.insert(thread_id, new_slot.clone()); + drop(registry); *slot.borrow_mut() = Some(new_slot); } }); } +/// Transition DETACHED → ATTACHED. Blocks if the thread was SUSPENDED by +/// a stop-the-world request (like `_PyThreadState_Attach` + `tstate_wait_attach`). 
+#[cfg(all(unix, feature = "threading"))] +fn wait_while_suspended(slot: &ThreadSlot) -> u64 { + let mut wait_yields = 0u64; + while slot.state.load(Ordering::Acquire) == THREAD_SUSPENDED { + wait_yields = wait_yields.saturating_add(1); + std::thread::park(); + } + wait_yields +} + +#[cfg(all(unix, feature = "threading"))] +fn attach_thread(vm: &VirtualMachine) { + CURRENT_THREAD_SLOT.with(|slot| { + if let Some(s) = slot.borrow().as_ref() { + super::stw_trace(format_args!("attach begin")); + loop { + match s.state.compare_exchange( + THREAD_DETACHED, + THREAD_ATTACHED, + Ordering::AcqRel, + Ordering::Relaxed, + ) { + Ok(_) => { + super::stw_trace(format_args!("attach DETACHED->ATTACHED")); + break; + } + Err(THREAD_SUSPENDED) => { + // Parked by stop-the-world — wait until released to DETACHED + super::stw_trace(format_args!("attach wait-suspended")); + let wait_yields = wait_while_suspended(s); + vm.state.stop_the_world.add_attach_wait_yields(wait_yields); + // Retry CAS + } + Err(state) => { + debug_assert!(false, "unexpected thread state in attach: {state}"); + break; + } + } + } + } + }); +} + +/// Transition ATTACHED → DETACHED (like `_PyThreadState_Detach`). +#[cfg(all(unix, feature = "threading"))] +fn detach_thread() { + CURRENT_THREAD_SLOT.with(|slot| { + if let Some(s) = slot.borrow().as_ref() { + match s.state.compare_exchange( + THREAD_ATTACHED, + THREAD_DETACHED, + Ordering::AcqRel, + Ordering::Acquire, + ) { + Ok(_) => {} + Err(THREAD_DETACHED) => { + debug_assert!(false, "detach called while already DETACHED"); + return; + } + Err(state) => { + debug_assert!(false, "unexpected thread state in detach: {state}"); + return; + } + } + super::stw_trace(format_args!("detach ATTACHED->DETACHED")); + } + }); +} + +/// Temporarily transition the current thread ATTACHED → DETACHED while +/// running `f`, then re-attach afterwards. This allows `stop_the_world` +/// to park this thread during blocking operations. 
+/// +/// `Py_BEGIN_ALLOW_THREADS` / `Py_END_ALLOW_THREADS` equivalent. +#[cfg(all(unix, feature = "threading"))] +pub fn allow_threads(vm: &VirtualMachine, f: impl FnOnce() -> R) -> R { + // Preserve save/restore semantics: + // only detach if this call observed ATTACHED at entry, and always restore + // on unwind. + let should_transition = CURRENT_THREAD_SLOT.with(|slot| { + slot.borrow() + .as_ref() + .is_some_and(|s| s.state.load(Ordering::Acquire) == THREAD_ATTACHED) + }); + if !should_transition { + return f(); + } + + detach_thread(); + let reattach_guard = scopeguard::guard(vm, attach_thread); + let result = f(); + drop(reattach_guard); + result +} + +/// No-op on non-unix or non-threading builds. +#[cfg(not(all(unix, feature = "threading")))] +pub fn allow_threads(_vm: &VirtualMachine, f: impl FnOnce() -> R) -> R { + f() +} + +/// Called from check_signals when stop-the-world is requested. +/// Transitions ATTACHED → SUSPENDED and waits until released +/// (like `_PyThreadState_Suspend` + `_PyThreadState_Attach`). +#[cfg(all(unix, feature = "threading"))] +pub fn suspend_if_needed(stw: &super::StopTheWorldState) { + let should_suspend = CURRENT_THREAD_SLOT.with(|slot| { + slot.borrow() + .as_ref() + .is_some_and(|s| s.stop_requested.load(Ordering::Relaxed)) + }); + if !should_suspend { + return; + } + + if !stw.requested.load(Ordering::Acquire) { + CURRENT_THREAD_SLOT.with(|slot| { + if let Some(s) = slot.borrow().as_ref() { + s.stop_requested.store(false, Ordering::Release); + } + }); + return; + } + + do_suspend(stw); +} + +#[cfg(all(unix, feature = "threading"))] +#[cold] +fn do_suspend(stw: &super::StopTheWorldState) { + CURRENT_THREAD_SLOT.with(|slot| { + if let Some(s) = slot.borrow().as_ref() { + // ATTACHED → SUSPENDED + match s.state.compare_exchange( + THREAD_ATTACHED, + THREAD_SUSPENDED, + Ordering::AcqRel, + Ordering::Acquire, + ) { + Ok(_) => { + // Consumed this thread's stop request bit. 
+ s.stop_requested.store(false, Ordering::Release); + } + Err(THREAD_DETACHED) => { + // Leaving VM; caller will re-check on next entry. + super::stw_trace(format_args!("suspend skip DETACHED")); + return; + } + Err(THREAD_SUSPENDED) => { + // Already parked by another path. + s.stop_requested.store(false, Ordering::Release); + super::stw_trace(format_args!("suspend skip already-suspended")); + return; + } + Err(state) => { + debug_assert!(false, "unexpected thread state in suspend: {state}"); + return; + } + } + super::stw_trace(format_args!("suspend ATTACHED->SUSPENDED")); + + // Re-check: if start_the_world already ran (cleared `requested`), + // no one will set us back to DETACHED — we must self-recover. + if !stw.requested.load(Ordering::Acquire) { + s.state.store(THREAD_ATTACHED, Ordering::Release); + s.stop_requested.store(false, Ordering::Release); + super::stw_trace(format_args!("suspend abort requested-cleared")); + return; + } + + // Notify the stop-the-world requester that we've parked + stw.notify_suspended(); + super::stw_trace(format_args!("suspend notified-requester")); + + // Wait until start_the_world sets us back to DETACHED + let wait_yields = wait_while_suspended(s); + stw.add_suspend_wait_yields(wait_yields); + + // Re-attach (DETACHED → ATTACHED), tstate_wait_attach CAS loop. 
+ loop { + match s.state.compare_exchange( + THREAD_DETACHED, + THREAD_ATTACHED, + Ordering::AcqRel, + Ordering::Acquire, + ) { + Ok(_) => break, + Err(THREAD_SUSPENDED) => { + let extra_wait = wait_while_suspended(s); + stw.add_suspend_wait_yields(extra_wait); + } + Err(THREAD_ATTACHED) => break, + Err(state) => { + debug_assert!(false, "unexpected post-suspend state: {state}"); + break; + } + } + } + s.stop_requested.store(false, Ordering::Release); + super::stw_trace(format_args!("suspend resume -> ATTACHED")); + } + }); +} + +#[cfg(all(unix, feature = "threading"))] +#[inline] +pub fn stop_requested_for_current_thread() -> bool { + CURRENT_THREAD_SLOT.with(|slot| { + slot.borrow() + .as_ref() + .is_some_and(|s| s.stop_requested.load(Ordering::Relaxed)) + }) +} + /// Push a frame pointer onto the current thread's shared frame stack. /// The pointed-to frame must remain alive until the matching pop. #[cfg(feature = "threading")] @@ -159,8 +424,42 @@ pub fn get_all_current_exceptions(vm: &VirtualMachine) -> Vec<(u64, Option registry.remove(&thread_id), + _ => None, + } + } else { + None + }; + #[cfg(all(unix, feature = "threading"))] + if let Some(slot) = &removed + && vm.state.stop_the_world.requested.load(Ordering::Acquire) + && thread_id != vm.state.stop_the_world.requester_ident() + && slot.state.load(Ordering::Relaxed) != THREAD_SUSPENDED + { + // A non-requester thread disappeared while stop-the-world is pending. + // Unblock requester countdown progress. + vm.state.stop_the_world.notify_thread_gone(); + } CURRENT_THREAD_SLOT.with(|s| { *s.borrow_mut() = None; }); @@ -174,11 +473,17 @@ pub fn cleanup_current_thread_frames(vm: &VirtualMachine) { /// VmState locks to unlocked. 
#[cfg(feature = "threading")] pub fn reinit_frame_slot_after_fork(vm: &VirtualMachine) { - let current_ident = crate::stdlib::thread::get_ident(); + let current_ident = crate::stdlib::_thread::get_ident(); let current_frames: Vec = vm.frames.borrow().clone(); let new_slot = Arc::new(ThreadSlot { frames: parking_lot::Mutex::new(current_frames), exception: crate::PyAtomicRef::from(vm.topmost_exception()), + #[cfg(unix)] + state: core::sync::atomic::AtomicI32::new(THREAD_ATTACHED), + #[cfg(unix)] + stop_requested: core::sync::atomic::AtomicBool::new(false), + #[cfg(unix)] + thread: std::thread::current(), }); // Lock is safe: reinit_locks_after_fork() already reset it to unlocked. @@ -329,6 +634,7 @@ impl VirtualMachine { async_gen_finalizer: RefCell::new(None), asyncio_running_loop: RefCell::new(None), asyncio_running_task: RefCell::new(None), + callable_cache: self.callable_cache.clone(), }; ThreadedVirtualMachine { vm } } diff --git a/crates/vm/src/vm/vm_ops.rs b/crates/vm/src/vm/vm_ops.rs index 1a362d67bed..d5c70f87386 100644 --- a/crates/vm/src/vm/vm_ops.rs +++ b/crates/vm/src/vm/vm_ops.rs @@ -1,5 +1,5 @@ use super::VirtualMachine; -use crate::stdlib::warnings; +use crate::stdlib::_warnings; use crate::{ PyRef, builtins::{PyInt, PyStr, PyStrRef, PyUtf8Str}, @@ -485,7 +485,7 @@ impl VirtualMachine { This returns the bitwise inversion of the underlying int object and is usually not what you expect from negating a bool. 
\ Use the 'not' operator for boolean negation or ~int(x) if you really want the bitwise inversion of the underlying int."; if a.fast_isinstance(self.ctx.types.bool_type) { - warnings::warn( + _warnings::warn( self.ctx.exceptions.deprecation_warning, STR.to_owned(), 1, diff --git a/crates/vm/src/warn.rs b/crates/vm/src/warn.rs index 7cefed3b1ae..0bae7a9619a 100644 --- a/crates/vm/src/warn.rs +++ b/crates/vm/src/warn.rs @@ -390,7 +390,7 @@ pub(crate) fn warn_explicit( vm, )?; let action_str = PyStrRef::try_from_object(vm, action) - .map_err(|_| vm.new_type_error("action must be a string".to_owned()))?; + .map_err(|_| vm.new_type_error("action must be a string"))?; if action_str.as_bytes() == b"error" { let exc = PyBaseExceptionRef::try_from_object(vm, message)?; @@ -470,13 +470,11 @@ fn call_show_warning( return show_warning(filename, lineno, text, category, source_line, vm); }; if !show_fn.is_callable() { - return Err( - vm.new_type_error("warnings._showwarnmsg() must be set to a callable".to_owned()) - ); + return Err(vm.new_type_error("warnings._showwarnmsg() must be set to a callable")); } let Some(warnmsg_cls) = get_warnings_attr(vm, identifier!(&vm.ctx, WarningMessage), false)? 
else { - return Err(vm.new_runtime_error("unable to get warnings.WarningMessage".to_owned())); + return Err(vm.new_runtime_error("unable to get warnings.WarningMessage")); }; let msg = warnmsg_cls.call( @@ -591,7 +589,7 @@ fn setup_context( .get_attr(identifier!(vm, __dict__), vm) .and_then(|d| { d.downcast::() - .map_err(|_| vm.new_type_error("sys.__dict__ is not a dictionary".to_owned())) + .map_err(|_| vm.new_type_error("sys.__dict__ is not a dictionary")) })?; (globals, vm.ctx.intern_str(""), 0) }; diff --git a/crates/wasm/src/convert.rs b/crates/wasm/src/convert.rs index a0186ce2834..bbf263975f3 100644 --- a/crates/wasm/src/convert.rs +++ b/crates/wasm/src/convert.rs @@ -49,7 +49,15 @@ pub fn py_err_to_js_err(vm: &VirtualMachine, py_err: &Py) -> Js serde_wasm_bindgen::to_value(&exceptions::SerializeException::new(vm, py_err)); match res { Ok(err_info) => PyError::new(err_info).into(), - Err(e) => e.into(), + Err(_) => { + // Fallback: create a basic JS Error with the exception type and message + let exc_type = py_err.class().name().to_string(); + let msg = match py_err.as_object().str(vm) { + Ok(s) => format!("{exc_type}: {s}"), + Err(_) => exc_type, + }; + js_sys::Error::new(&msg).into() + } } } } diff --git a/extra_tests/snippets/vm_specialization.py b/extra_tests/snippets/vm_specialization.py new file mode 100644 index 00000000000..2c884cc2f6d --- /dev/null +++ b/extra_tests/snippets/vm_specialization.py @@ -0,0 +1,71 @@ +## BinaryOp inplace-add unicode: deopt falls back to __add__/__iadd__ + + +class S(str): + def __add__(self, other): + return "ADD" + + def __iadd__(self, other): + return "IADD" + + +def add_path_fallback_uses_add(): + x = "a" + y = "b" + for i in range(1200): + if i == 600: + x = S("s") + y = "t" + x = x + y + return x + + +def iadd_path_fallback_uses_iadd(): + x = "a" + y = "b" + for i in range(1200): + if i == 600: + x = S("s") + y = "t" + x += y + return x + + +assert add_path_fallback_uses_add().startswith("ADD") +assert 
iadd_path_fallback_uses_iadd().startswith("IADD") + + +## BINARY_SUBSCR_STR_INT: ASCII singleton identity + + +def check_ascii_subscr_singleton_after_warmup(): + s = "abc" + first = None + for i in range(4000): + c = s[0] + if i >= 3500: + if first is None: + first = c + else: + assert c is first + + +check_ascii_subscr_singleton_after_warmup() + + +## BINARY_SUBSCR_STR_INT: Latin-1 singleton identity + + +def check_latin1_subscr_singleton_after_warmup(): + for s in ("abc", "éx"): + first = None + for i in range(5000): + c = s[0] + if i >= 4500: + if first is None: + first = c + else: + assert c is first + + +check_latin1_subscr_singleton_after_warmup() diff --git a/extra_tests/test_manager_fork_debug.py b/extra_tests/test_manager_fork_debug.py new file mode 100644 index 00000000000..6110f7e3699 --- /dev/null +++ b/extra_tests/test_manager_fork_debug.py @@ -0,0 +1,149 @@ +"""Minimal reproduction of multiprocessing Manager + fork failure.""" + +import multiprocessing +import os +import sys +import time +import traceback + +import pytest + +pytestmark = pytest.mark.skipif(not hasattr(os, "fork"), reason="requires os.fork") + + +def test_basic_manager(): + """Test Manager without fork - does it work at all?""" + print("=== Test 1: Basic Manager (no fork) ===") + ctx = multiprocessing.get_context("fork") + manager = ctx.Manager() + try: + ev = manager.Event() + print(f" Event created: {ev}") + ev.set() + print(f" Event set, is_set={ev.is_set()}") + assert ev.is_set() + print(" PASS") + finally: + manager.shutdown() + + +def test_manager_with_process(): + """Test Manager shared between parent and child process.""" + print("\n=== Test 2: Manager with forked child ===") + ctx = multiprocessing.get_context("fork") + manager = ctx.Manager() + try: + result = manager.Value("i", 0) + ev = manager.Event() + + def child_fn(): + try: + ev.set() + result.value = 42 + except Exception as e: + print(f" CHILD ERROR: {e}", file=sys.stderr) + traceback.print_exc() + sys.exit(1) + + 
print(f" Starting child process...") + process = ctx.Process(target=child_fn) + process.start() + print(f" Waiting for child (pid={process.pid})...") + process.join(timeout=10) + + if process.exitcode != 0: + print(f" FAIL: child exited with code {process.exitcode}") + return False + + print(f" Child done. result={result.value}, event={ev.is_set()}") + assert result.value == 42 + assert ev.is_set() + print(" PASS") + return True + finally: + manager.shutdown() + + +def test_manager_server_alive_after_fork(): + """Test that Manager server survives after forking a child.""" + print("\n=== Test 3: Manager server alive after fork ===") + ctx = multiprocessing.get_context("fork") + manager = ctx.Manager() + try: + ev = manager.Event() + + # Fork a child that does nothing with the manager + pid = os.fork() + if pid == 0: + # Child - exit immediately + os._exit(0) + + # Parent - wait for child + os.waitpid(pid, 0) + + # Now try to use the manager in the parent + print(f" After fork, trying to use Manager in parent...") + ev.set() + print(f" ev.is_set() = {ev.is_set()}") + assert ev.is_set() + print(" PASS") + return True + finally: + manager.shutdown() + + +def test_manager_server_alive_after_fork_with_child_usage(): + """Test that Manager server survives when child also uses it.""" + print("\n=== Test 4: Manager server alive after fork + child usage ===") + ctx = multiprocessing.get_context("fork") + manager = ctx.Manager() + try: + child_ev = manager.Event() + parent_ev = manager.Event() + + def child_fn(): + try: + child_ev.set() + except Exception as e: + print(f" CHILD ERROR: {e}", file=sys.stderr) + traceback.print_exc() + sys.exit(1) + + process = ctx.Process(target=child_fn) + process.start() + process.join(timeout=10) + + if process.exitcode != 0: + print(f" FAIL: child exited with code {process.exitcode}") + return False + + # Now use manager in parent AFTER child is done + print(f" Child done. 
Trying parent usage...") + parent_ev.set() + print(f" child_ev={child_ev.is_set()}, parent_ev={parent_ev.is_set()}") + assert child_ev.is_set() + assert parent_ev.is_set() + print(" PASS") + return True + finally: + manager.shutdown() + + +if __name__ == "__main__": + test_basic_manager() + + passed = 0 + total = 10 + for i in range(total): + print(f"\n--- Iteration {i + 1}/{total} ---") + ok = True + ok = ok and test_manager_with_process() + ok = ok and test_manager_server_alive_after_fork() + ok = ok and test_manager_server_alive_after_fork_with_child_usage() + if ok: + passed += 1 + else: + print(f" FAILED on iteration {i + 1}") + + print(f"\n=== Results: {passed}/{total} passed ===") + sys.exit(0 if passed == total else 1) diff --git a/scripts/whats_left.py b/scripts/whats_left.py index 00db9a0ac5c..9a4d57df6ae 100755 --- a/scripts/whats_left.py +++ b/scripts/whats_left.py @@ -67,9 +67,9 @@ def parse_args(): ) parser.add_argument( "--features", - action="store", - help="which features to enable when building RustPython (default: ssl)", - default="ssl", + action="append", + help="which features to enable when building RustPython (default: [])", + default=[], ) args = parser.parse_args() @@ -449,16 +449,20 @@ def remove_one_indent(s): cargo_build_command = ["cargo", "build", "--release"] if args.no_default_features: cargo_build_command.append("--no-default-features") + +joined_features = ",".join(args.features) if args.features: - cargo_build_command.extend(["--features", args.features]) + cargo_build_command.extend(["--features", joined_features]) subprocess.run(cargo_build_command, check=True) cargo_run_command = ["cargo", "run", "--release"] if args.no_default_features: cargo_run_command.append("--no-default-features") + if args.features: - cargo_run_command.extend(["--features", args.features]) + cargo_run_command.extend(["--features", joined_features]) + cargo_run_command.extend(["-q", "--", GENERATED_FILE]) result = subprocess.run(