Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ name = "discord-cluster-manager"
version = "0.1.0"
description = "Discord bot for managing compute clusters and running kernel benchmarks"
readme = "README.md"
requires-python = ">=3.10"
requires-python = ">=3.11"
dependencies = [
"PyGithub",
"aiohttp",
Expand All @@ -25,6 +25,7 @@ dependencies = [
"jinja2",
"huggingface-hub>=0.20",
"pyarrow>=14.0",
"kernelguard>=0.1.1",
]

[project.optional-dependencies]
Expand Down
9 changes: 8 additions & 1 deletion src/kernelbot/api/api_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from kernelbot.env import env
from libkernelbot.backend import KernelBackend
from libkernelbot.consts import SubmissionMode
from libkernelbot.kernelguard import KernelGuardRejected
from libkernelbot.leaderboard_db import LeaderboardDB
from libkernelbot.report import (
Log,
Expand All @@ -18,6 +19,7 @@
SubmissionRequest,
prepare_submission,
)
from libkernelbot.utils import KernelBotError


async def _handle_discord_oauth(code: str, redirect_uri: str) -> tuple[str, str]:
Expand Down Expand Up @@ -154,7 +156,12 @@ async def _run_submission(
raise HTTPException(status_code=400, detail="Invalid GPU type")

reporter = MultiProgressReporterAPI()
sub_id, results = await backend.submit_full(req, mode, reporter)
try:
sub_id, results = await backend.submit_full(req, mode, reporter)
except KernelGuardRejected as e:
raise HTTPException(status_code=400, detail=str(e)) from e
except KernelBotError as e:
raise HTTPException(status_code=getattr(e, "http_code", 500), detail=str(e)) from e
return results, [rep.get_message() + "\n" + rep.long_report for rep in reporter.runs]


Expand Down
13 changes: 13 additions & 0 deletions src/kernelbot/api/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from libkernelbot.background_submission_manager import BackgroundSubmissionManager
from libkernelbot.consts import SubmissionMode
from libkernelbot.db_types import IdentityType
from libkernelbot.kernelguard import KernelGuardRejected, enforce_submission_precheck, should_precheck_submission
from libkernelbot.leaderboard_db import LeaderboardDB, LeaderboardRankedEntry
from libkernelbot.problem_sync import sync_problems
from libkernelbot.submission import (
Expand Down Expand Up @@ -563,6 +564,18 @@ async def run_submission_async(
if not req.gpus or len(req.gpus) != 1:
raise HTTPException(status_code=400, detail="Invalid GPU type")

# run KernelGuard pre-check before enqueuing to avoid filling the queue with blocked submissions
if should_precheck_submission(submission_mode_enum):
try:
await asyncio.wait_for(
asyncio.to_thread(enforce_submission_precheck, req.code, req.file_name),
timeout=5.0,
)
except asyncio.TimeoutError as e:
raise HTTPException(status_code=504, detail="KernelGuard pre-check timed out") from e
except KernelGuardRejected as e:
raise HTTPException(status_code=400, detail=str(e)) from e

# put submission request to background manager to run in background
sub_id, job_status_id = await enqueue_background_job(
req, submission_mode_enum, backend_instance, background_submission_manager
Expand Down
35 changes: 32 additions & 3 deletions src/libkernelbot/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,11 @@
from typing import Optional

from libkernelbot.consts import GPU, GPU_TO_SM, SubmissionMode, get_gpu_by_name, get_mode_category
from libkernelbot.kernelguard import (
KernelGuardRejected,
enforce_submission_precheck,
should_precheck_submission,
)
from libkernelbot.launchers import Launcher
from libkernelbot.leaderboard_db import LeaderboardDB
from libkernelbot.report import (
Expand Down Expand Up @@ -53,10 +58,11 @@ async def submit_full(
mode: SubmissionMode,
reporter: MultiProgressReporter,
pre_sub_id: Optional[int] = None,
skip_precheck: bool = False,
):
"""
pre_sub_id is used to pass the submission id which is created beforehand.

skip_precheck skips the KernelGuard pre-check (use when the caller already ran it).
"""
if pre_sub_id is not None:
sub_id = pre_sub_id
Expand All @@ -72,7 +78,29 @@ async def submit_full(
mode_category=req.mode_category or get_mode_category(mode),
)
selected_gpus = [get_gpu_by_name(gpu) for gpu in req.gpus]
submission_started = False
try:
if not skip_precheck and should_precheck_submission(mode):
try:
await asyncio.to_thread(enforce_submission_precheck, req.code, req.file_name)
except KernelGuardRejected as exc:
logger.error(
"Submission %s rejected by precheck: file=%s, mode=%s, error=%s",
sub_id, req.file_name, mode, str(exc)
)
with self.db as db:
db.mark_submission_hacked(sub_id, error=str(exc))
raise
except Exception as exc:
logger.error(
"Submission %s precheck unavailable: file=%s, mode=%s, error=%s",
sub_id, req.file_name, mode, str(exc)
)
with self.db as db:
db.mark_submission_done(sub_id)
raise

submission_started = True
tasks = [
self.submit_leaderboard(
sub_id,
Expand Down Expand Up @@ -106,8 +134,9 @@ async def submit_full(
)
results = await asyncio.gather(*tasks)
finally:
with self.db as db:
db.mark_submission_done(sub_id)
if submission_started:
with self.db as db:
db.mark_submission_done(sub_id)
return sub_id, results

async def submit_leaderboard( # noqa: C901
Expand Down
19 changes: 18 additions & 1 deletion src/libkernelbot/background_submission_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from libkernelbot.backend import KernelBackend
from libkernelbot.consts import SubmissionMode
from libkernelbot.kernelguard import KernelGuardRejected
from libkernelbot.report import MultiProgressReporter, RunProgressReporter, RunResultReport
from libkernelbot.submission import ProcessedSubmissionRequest
from libkernelbot.utils import setup_logging
Expand Down Expand Up @@ -233,7 +234,7 @@ async def heartbeat():
reporter = BackgroundSubmissionManagerReporter()
await asyncio.wait_for(
self.backend.submit_full(
item.req, item.mode, reporter, sub_id
item.req, item.mode, reporter, sub_id, skip_precheck=True
),
timeout=HARD_TIMEOUT_SEC,
)
Expand All @@ -252,6 +253,22 @@ async def heartbeat():
last_heartbeat=ts,
error="hard timeout reached",
)
except KernelGuardRejected as e:
ts = dt.datetime.now(dt.timezone.utc)
logger.info("[Background Job] submission %s flagged as hacked", sub_id)
try:
with self.backend.db as db:
db.upsert_submission_job_status(
sub_id,
status="hacked",
last_heartbeat=ts,
error=str(e),
)
except Exception:
logger.error(
"[Background Job] Failed to write hacked status for submission %s",
sub_id,
)
except Exception as e:
ts = dt.datetime.now(dt.timezone.utc)
logger.error(
Expand Down
159 changes: 159 additions & 0 deletions src/libkernelbot/kernelguard.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
import json
import os
import shlex
import shutil
import subprocess
from typing import Any

from libkernelbot.consts import SubmissionMode
from libkernelbot.utils import KernelBotError, limit_length, setup_logging

logger = setup_logging(__name__)

# Environment-variable spellings that count as "enabled"; values are compared
# after stripping whitespace and lower-casing.
_TRUE_VALUES = {"1", "true", "yes", "on"}

# Subprocess timeout (seconds) used when KERNELGUARD_TIMEOUT_SEC is unset or
# cannot be parsed as an integer.
_DEFAULT_TIMEOUT_SEC = 30

# Submission modes that are subject to the KernelGuard pre-check.
_GUARDED_MODES = frozenset(
    (
        SubmissionMode.BENCHMARK,
        SubmissionMode.PROFILE,
        SubmissionMode.LEADERBOARD,
        SubmissionMode.PRIVATE,
    )
)


class KernelGuardRejected(KernelBotError):
    """Raised when the KernelGuard pre-check decides a submission must be filtered.

    Attributes are documented inline; the human-readable rejection reason is
    the exception message itself.
    """

    def __init__(self, message: str, result: dict[str, Any]):
        super().__init__(message)
        # Raw analysis payload returned by KernelGuard, kept for callers that
        # want more detail than the message carries.
        self.result = result


def _env_enabled(name: str, default: bool = False) -> bool:
    """Interpret the environment variable ``name`` as a boolean flag.

    Returns ``default`` when the variable is not set at all; otherwise returns
    whether its stripped, lower-cased value is one of ``_TRUE_VALUES``.
    """
    value = os.environ.get(name)
    return default if value is None else value.strip().lower() in _TRUE_VALUES


def should_precheck_submission(mode: SubmissionMode) -> bool:
    """Tell whether a submission in ``mode`` must go through the KernelGuard pre-check.

    True only when the KERNELGUARD_ENABLED flag is on and ``mode`` is one of
    the guarded submission modes.
    """
    if not _env_enabled("KERNELGUARD_ENABLED"):
        return False
    return mode in _GUARDED_MODES


def _timeout_sec() -> int:
    """Return the KernelGuard subprocess timeout in seconds (never below 1).

    Reads KERNELGUARD_TIMEOUT_SEC; a value that is not an integer is logged
    as a warning and replaced by ``_DEFAULT_TIMEOUT_SEC``.
    """
    configured = os.getenv("KERNELGUARD_TIMEOUT_SEC", str(_DEFAULT_TIMEOUT_SEC)).strip()
    try:
        seconds = int(configured)
    except ValueError:
        logger.warning(
            "Invalid KERNELGUARD_TIMEOUT_SEC=%r, using %d", configured, _DEFAULT_TIMEOUT_SEC
        )
        return _DEFAULT_TIMEOUT_SEC
    # Clamp zero/negative settings up to a one-second floor.
    return seconds if seconds > 1 else 1


def _profile() -> str | None:
raw = os.getenv("KERNELGUARD_PROFILE", "").strip()
return raw or None


def _config_path() -> str | None:
raw = os.getenv("KERNELGUARD_CONFIG", "").strip()
return raw or None


def _fail_open_enabled() -> bool:
    """Report whether the KERNELGUARD_FAIL_OPEN flag is switched on (off when unset)."""
    fail_open = _env_enabled("KERNELGUARD_FAIL_OPEN")
    return fail_open


def _default_command() -> list[str]:
    """Pick the KernelGuard launcher available on PATH.

    Preference order: a ``kernelguard`` executable, then ``kguard``, then
    running it through ``uvx``.

    Raises:
        FileNotFoundError: if none of the three executables can be found.
    """
    direct = next(
        (name for name in ("kernelguard", "kguard") if shutil.which(name)),
        None,
    )
    if direct is not None:
        return [direct]
    if not shutil.which("uvx"):
        raise FileNotFoundError("Could not find `kernelguard`, `kguard`, or `uvx` in PATH")
    return ["uvx", "kernelguard"]


def _command() -> list[str]:
    """Return the argv prefix used to launch KernelGuard.

    A non-blank KERNELGUARD_COMMAND is split with shell-like rules and used
    as-is; otherwise the launcher is discovered on PATH.
    """
    override = os.environ.get("KERNELGUARD_COMMAND", "").strip()
    if not override:
        return _default_command()
    return shlex.split(override)


def _analyze_with_cli(code: str) -> dict[str, Any]:
    """Run the KernelGuard CLI once on ``code`` and return its JSON verdict.

    The source is fed on stdin; the optional ``--profile`` / ``--config``
    flags are appended when configured, followed by ``--api-mode``. The last
    non-blank stdout line is parsed as the result.

    Raises:
        RuntimeError: on a non-zero exit status, empty output, output whose
            last line is not valid JSON, or a JSON value that is not an object.
        subprocess.TimeoutExpired: if the CLI exceeds the configured timeout
            (propagated from ``subprocess.run``).
    """
    argv = list(_command())
    for flag, value in (("--profile", _profile()), ("--config", _config_path())):
        if value is not None:
            argv += [flag, value]
    argv.append("--api-mode")

    completed = subprocess.run(
        argv,
        input=code,
        text=True,
        capture_output=True,
        timeout=_timeout_sec(),
        check=False,  # exit status is inspected by hand to build a richer error
    )
    if completed.returncode != 0:
        # Truncate both streams so the error message stays bounded.
        err_text = limit_length(completed.stderr.strip(), 300) if completed.stderr else ""
        out_text = limit_length(completed.stdout.strip(), 300) if completed.stdout else ""
        raise RuntimeError(
            "KernelGuard command failed "
            f"(exit={completed.returncode}, stdout={out_text!r}, stderr={err_text!r})"
        )

    non_blank = [ln for ln in completed.stdout.splitlines() if ln.strip()]
    if not non_blank:
        raise RuntimeError("KernelGuard returned no JSON result")

    # Only the final non-blank line is treated as the result payload.
    final_line = non_blank[-1]
    try:
        payload = json.loads(final_line)
    except json.JSONDecodeError as exc:
        raise RuntimeError(f"KernelGuard returned invalid JSON: {final_line!r}") from exc

    if isinstance(payload, dict):
        return payload
    raise RuntimeError("KernelGuard returned a non-object JSON payload")


def analyze_submission(code: str) -> dict[str, Any]:
    """Analyze ``code`` with KernelGuard and return the result object."""
    # The one-shot CLI invocation is the only path used, which keeps
    # KERNELGUARD_TIMEOUT_SEC enforced on every analysis.
    verdict = _analyze_with_cli(code)
    return verdict


def enforce_submission_precheck(code: str, file_name: str) -> dict[str, Any] | None:
    """Run the KernelGuard pre-check on a submission and reject it if flagged.

    Args:
        code: Source text of the submission to analyze.
        file_name: Name of the submitted file; used only in log messages.

    Returns:
        The KernelGuard result object when the submission is accepted, or
        ``None`` when KernelGuard is disabled or when the analysis failed and
        KERNELGUARD_FAIL_OPEN is enabled.

    Raises:
        KernelGuardRejected: if the result has a truthy ``should_filter``.
        KernelBotError: (code 503) if the analysis itself failed and fail-open
            is not enabled.
    """
    if not _env_enabled("KERNELGUARD_ENABLED"):
        return None

    try:
        result = analyze_submission(code)
    except Exception as exc:
        logger.warning("KernelGuard pre-check failed for %s", file_name, exc_info=exc)
        if _fail_open_enabled():
            return None
        raise KernelBotError(
            "KernelGuard pre-check is unavailable right now. Please try again later.",
            code=503,
        ) from exc

    classification = str(result.get("classification", "unknown"))
    if result.get("should_filter"):
        # `dict.get(key, [])` only falls back when the key is absent. A payload
        # with `"matched_patterns": null` (or a scalar) would otherwise raise
        # TypeError here and mask the rejection as a generic failure.
        raw_matches = result.get("matched_patterns")
        try:
            matches = list(raw_matches) if raw_matches is not None else []
        except TypeError:
            matches = []
        patterns = sorted(
            {
                str(item.get("pattern", "unknown"))
                for item in matches
                if isinstance(item, dict)
            }
        )
        reason = str(result.get("filter_reason") or classification)
        details = f"Submission rejected by KernelGuard pre-check ({reason})"
        if patterns:
            details += f". Matched rules: {', '.join(patterns)}"
        raise KernelGuardRejected(details + ".", result=result)

    if classification != "valid":
        # Accepted but not clean: leave a trace without blocking the submission.
        logger.info("KernelGuard classified %s as %s", file_name, classification)

    return result
31 changes: 31 additions & 0 deletions src/libkernelbot/leaderboard_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,37 @@ def mark_submission_done(
self.connection.rollback() # Ensure rollback if error occurs
raise KernelBotError("Error while finalizing submission") from e

def mark_submission_hacked(self, submission: int, error: str | None = None) -> None:
    """Record that a submission was flagged as hacked.

    Marks the row in ``leaderboard.submission`` as done with status
    ``'hacked'`` and upserts a matching ``leaderboard.submission_job_status``
    row, both committed together.

    Args:
        submission: Id of the submission to flag.
        error: Optional error text; on conflict a NULL value keeps the
            previously stored error.

    Raises:
        KernelBotError: if a database error occurs (the transaction is rolled
            back first).
    """
    flagged_at = datetime.datetime.now(datetime.timezone.utc)
    job_status_row = (submission, "hacked", error, flagged_at)
    try:
        self.cursor.execute(
            """
            UPDATE leaderboard.submission
            SET done = TRUE, status = 'hacked'
            WHERE id = %s
            """,
            (submission,),
        )
        self.cursor.execute(
            """
            INSERT INTO leaderboard.submission_job_status AS s
            (submission_id, status, error, last_heartbeat)
            VALUES
            (%s, %s, %s, %s)
            ON CONFLICT (submission_id) DO UPDATE
            SET
            status = EXCLUDED.status,
            error = COALESCE(EXCLUDED.error, s.error),
            last_heartbeat = EXCLUDED.last_heartbeat
            """,
            job_status_row,
        )
        self.connection.commit()
    except psycopg2.Error as db_err:
        logger.error("Could not mark submission '%s' as hacked.", submission, exc_info=db_err)
        # Undo the partial transaction before surfacing the failure.
        self.connection.rollback()
        raise KernelBotError("Error while recording hacked submission") from db_err

def update_heartbeat_if_active(self, sub_id: int, ts: datetime.datetime) -> None:
try:
self.cursor.execute(
Expand Down
Loading
Loading