Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
276 changes: 274 additions & 2 deletions traincheck/instrumentor/source_file.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import ast
import logging
import re

import tokenize
import io
from traincheck.config.config import INSTR_MODULES_TO_INSTR
from collections import deque
from typing import Dict, Set

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -502,6 +505,274 @@ def instrument_model_tracker_sampler(
return source


def annotate_stage(
    source: str,
) -> str:
    """Insert ``annotate_stage(...)`` markers into Python source text.

    The source is tokenized and every ``.train(`` / ``.step(`` call is
    treated as a "training" site, every ``.eval(`` / ``.no_grad(`` call as
    a "testing" site and every ``.save(`` call as a "checkpointing" site.
    A marker line is inserted in front of the statement containing such a
    call. When one statement matches several stages, "training" wins over
    "testing", which wins over "checkpointing". A stage that already has an
    ``annotate_stage("<stage>")`` call anywhere in ``source`` is left alone.

    ``from traincheck import annotate_stage`` is added when missing, and an
    ``annotate_stage("init")`` marker is placed either at the start of the
    body of ``main()`` (when the source has both a single-line
    ``def main(...):`` and an ``if __name__ == "__main__":`` guard) or
    directly after the import.

    Args:
        source: Python source code to annotate.

    Returns:
        The annotated source code.

    Raises:
        RuntimeError: If neither the input nor the output contains any
            ``annotate_stage(`` call.
        tokenize.TokenError: Propagated from ``tokenize`` when ``source``
            cannot be tokenized (IndentationError is possible as well).
    """

    def _ctx(msg: str) -> str:
        return f"[annotate_stage] {msg}"

    def has_stage(src: str, name: str) -> bool:
        pattern = rf'annotate_stage\(\s*[\'"]{name}[\'"]\s*\)'
        return re.search(pattern, src) is not None

    def _split_lines(text: str) -> list:
        # Split on "\n" only. str.splitlines() also breaks on form feeds and
        # other separators, which would desynchronise list indices from the
        # tokenizer's line numbers and from the "\n" counting done below.
        return io.StringIO(text).readlines()

    def _indent_of(line: str) -> str:
        return re.match(r"\s*", line).group(0)

    def _insert_line(lines: list, idx: int, text: str) -> None:
        # A final line without a trailing newline must be terminated first,
        # otherwise the inserted text would be glued onto it.
        if idx > 0 and not lines[idx - 1].endswith("\n"):
            lines[idx - 1] += "\n"
        lines.insert(idx, text)

    def _find_annotate_import_idx(lines: list) -> int:
        import_re = r"^\s*from\s+traincheck\s+import\s+annotate_stage\s*$"
        for idx, text in enumerate(lines):
            if re.match(import_re, text):
                return idx
        return -1

    orig_has = {
        name: has_stage(source, name)
        for name in ("init", "training", "testing", "checkpointing")
    }
    orig_has_any = any(orig_has.values()) or ("annotate_stage(" in source)

    for stage_name, present in orig_has.items():
        if present:
            logger.info(
                _ctx(
                    f"Stage '{stage_name}' already present in source; skip adding this stage."
                )
            )

    # Attribute-call names that mark a stage.
    attr_to_stage = {
        "train": "training",
        "step": "training",
        "eval": "testing",
        "no_grad": "testing",
        "save": "checkpointing",
    }
    # A statement cannot be placed directly in front of these clause headers.
    clause_keywords = {"elif", "else", "except", "finally"}
    layout_tokens = {
        tokenize.COMMENT,
        tokenize.NL,
        tokenize.INDENT,
        tokenize.DEDENT,
        tokenize.ENDMARKER,
    }
    stage_lines: Dict[str, Set[int]] = {
        "training": set(),
        "testing": set(),
        "checkpointing": set(),
    }

    # Last line (1-based) of the leading module docstring / __future__
    # imports; 0 when the module has neither.
    header_end = 0
    in_header = True
    stmt_tokens: list = []  # significant tokens of the current logical line
    window = deque(maxlen=3)
    for tok in tokenize.generate_tokens(io.StringIO(source).readline):
        window.append(tok)
        if tok.type == tokenize.NEWLINE:
            # End of a logical line: decide whether it still belongs to the
            # module header that the import must not precede.
            if in_header and stmt_tokens:
                is_docstring = header_end == 0 and all(
                    t.type == tokenize.STRING for t in stmt_tokens
                )
                is_future = (
                    len(stmt_tokens) >= 2
                    and stmt_tokens[0].string == "from"
                    and stmt_tokens[1].string == "__future__"
                )
                if is_docstring or is_future:
                    header_end = tok.start[0]
                else:
                    in_header = False
            stmt_tokens = []
            continue
        if tok.type in layout_tokens:
            continue
        stmt_tokens.append(tok)

        if len(window) < 3:
            continue
        dot, attr, paren = window
        is_attr_call = (
            dot.type == tokenize.OP
            and dot.string == "."
            and attr.type == tokenize.NAME
            and paren.type == tokenize.OP
            and paren.string == "("
        )
        if not is_attr_call:
            continue
        stage = attr_to_stage.get(attr.string)
        if stage is None or orig_has[stage]:
            continue
        # Anchor the marker at the first line of the enclosing statement so
        # it is never spliced into the middle of a multi-line expression.
        first = stmt_tokens[0]
        if first.type == tokenize.NAME and first.string in clause_keywords:
            logger.info(
                _ctx(
                    f"Skip inserting '{stage}' at line {first.start[0]}: "
                    f"cannot insert a statement before '{first.string}'."
                )
            )
            continue
        stage_lines[stage].add(first.start[0])

    # Assign in ascending priority so that a later (higher priority) stage
    # overrides an earlier one on the same line.
    line_to_stage: Dict[int, str] = {}
    for stage in ("checkpointing", "testing", "training"):
        for ln in stage_lines[stage]:
            line_to_stage[ln] = stage

    new_lines: list = []
    header_idx = 0  # index in new_lines just past the module header
    for lineno, line in enumerate(_split_lines(source), start=1):
        stage = line_to_stage.get(lineno)
        if stage:
            k = len(new_lines) - 1
            while k >= 0 and new_lines[k].strip() == "":
                k -= 1
            prev = new_lines[k] if k >= 0 else ""
            already_marked = ("annotate_stage" in prev) and (
                f'"{stage}"' in prev or f"'{stage}'" in prev
            )
            if not already_marked:
                new_lines.append(f'{_indent_of(line)}annotate_stage("{stage}")\n')
                logger.info(
                    _ctx(
                        f"Inserted stage '{stage}' before line {lineno}: {line.strip()}"
                    )
                )
            else:
                logger.info(
                    _ctx(
                        f"Skip inserting '{stage}' at line {lineno} (previous non-empty line already has it)."
                    )
                )
        new_lines.append(line)
        if lineno == header_end:
            header_idx = len(new_lines)

    out = new_lines
    annot_import_idx = _find_annotate_import_idx(out)

    if annot_import_idx == -1:
        insert_idx = 0
        while insert_idx < len(out):
            s = out[insert_idx].strip()
            if (
                out[insert_idx].startswith("#!")
                or (s.startswith("#") and "coding" in s)
                or s.startswith("from __future__ import")
            ):
                insert_idx += 1
            else:
                break
        # Never place the import before the module docstring or a
        # ``from __future__`` import: the latter would be a SyntaxError.
        insert_idx = max(insert_idx, header_idx)
        _insert_line(out, insert_idx, "from traincheck import annotate_stage\n")
        annot_import_idx = insert_idx
        logger.info(
            _ctx(
                f"Inserted import 'from traincheck import annotate_stage' at line {annot_import_idx + 1}."
            )
        )

    new_src = "".join(out)

    if not orig_has["init"]:
        has_guard = (
            re.search(
                r'^\s*if\s+__name__\s*==\s*[\'"]__main__[\'"]\s*:\s*$', new_src, re.M
            )
            is not None
        )
        main_def = re.search(
            r"^([ \t]*)def\s+main\s*\(.*?\)\s*:\s*(?:#.*)?$", new_src, re.M
        )

        if has_guard and main_def:
            def_line_idx = new_src.count("\n", 0, main_def.start())
            def_indent = main_def.group(1)

            # Reuse the indentation of the first real body line so that the
            # marker lines up with the existing body whatever its width.
            body_indent = None
            for candidate in out[def_line_idx + 1 :]:
                stripped = candidate.strip()
                if not stripped or stripped.startswith("#"):
                    continue
                lead = _indent_of(candidate)
                if lead.startswith(def_indent) and len(lead) > len(def_indent):
                    body_indent = lead
                break
            if body_indent is None:
                # No usable body line: fall back to one conventional level.
                use_tab = "\t" in def_indent and " " not in def_indent
                body_indent = def_indent + ("\t" if use_tab else "    ")

            insert_at = def_line_idx + 1
            while insert_at < len(out) and out[insert_at].strip() == "":
                insert_at += 1

            # Step over a triple-quoted docstring so it stays the first
            # statement of main().
            if insert_at < len(out):
                head = out[insert_at].lstrip()
                if head.startswith('"""') or head.startswith("'''"):
                    quote = head[:3]
                    if out[insert_at].count(quote) >= 2:
                        insert_at += 1  # opens and closes on one line
                    else:
                        insert_at += 1
                        while insert_at < len(out):
                            closes_here = quote in out[insert_at]
                            insert_at += 1
                            if closes_here:
                                break

            k = insert_at - 1
            while k >= 0 and out[k].strip() == "":
                k -= 1
            prev = out[k] if k >= 0 else ""
            if not (("annotate_stage" in prev) and ("init" in prev)):
                _insert_line(out, insert_at, f'{body_indent}annotate_stage("init")\n')
                logger.info(
                    _ctx(
                        f"Inserted stage 'init' at start of main() body (line {insert_at + 1})."
                    )
                )
            else:
                logger.info(
                    _ctx(
                        "Skip inserting 'init' inside main(): previous non-empty line already has it."
                    )
                )
        else:
            # The import exists at this point (found or inserted above), so
            # the marker goes right after it, at the import's indentation.
            insert_at = annot_import_idx + 1
            import_indent = _indent_of(out[annot_import_idx])

            k = insert_at
            while k < len(out) and out[k].strip() == "":
                k += 1
            next_line = out[k] if k < len(out) else ""
            if not (("annotate_stage" in next_line) and ("init" in next_line)):
                _insert_line(
                    out, insert_at, f'{import_indent}annotate_stage("init")\n'
                )
                logger.info(
                    _ctx(
                        f"Inserted stage 'init' right after annotate_stage import at line {insert_at + 1}."
                    )
                )
            else:
                logger.info(
                    _ctx(
                        "Skip inserting 'init': next non-empty line after annotate_stage import is already init."
                    )
                )

        new_src = "".join(out)

    if "annotate_stage(" not in new_src and not orig_has_any:
        logger.error(
            _ctx(
                "Automatic insertion failed: no annotate_stage(...) found or added. Manual insertion required."
            )
        )
        raise RuntimeError(
            _ctx(
                "Automatic insertion failed: no annotate_stage(...) found or added. Manual insertion required."
            )
        )

    return new_src


def instrument_file(
path: str,
modules_to_instr: list[str],
Expand Down Expand Up @@ -532,7 +803,8 @@ def instrument_file(
funcs_to_instr,
API_dump_stack_trace,
)

# annotate stages
instrumented_source = annotate_stage(instrumented_source)
# logging configs
logging_start_code = f"""
import os
Expand Down
Loading