Skip to content

Commit af8dff8

Browse files
spikerheado1234Ubuntu
authored and committed
add correctness check
fix ASCII correctness; change precision check to 1e-2
1 parent ee0372e commit af8dff8

File tree

3 files changed

+477
-0
lines changed

3 files changed

+477
-0
lines changed
Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
#!/bin/bash

# Correctness test suite for autosp vs baseline compiled DS-Ulysses.
#
# For each (sp_size, zero_stage) configuration:
#   1. Runs baseline (--compile compile) for N steps
#   2. Runs autosp (--compile autosp) for N steps
#   3. Compares per-rank losses with validator.py
#
# Usage:
#   ./correctness.sh          # Test sp-sizes 1, 2, 4, 8
#   ./correctness.sh 1 2 8    # Test only sp-sizes 1, 2, 8

SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
OUTPUT_DIR="${SCRIPT_DIR}/output"
STEPS=5

# Remove the scratch output dir on exit, but ONLY when every test passed.
# correctness_run.py stores per-test training logs under OUTPUT_DIR and
# points the user at them on failure ("See: ...log"); deleting the dir
# unconditionally would destroy the only debugging artifact.
cleanup() {
    if [ "${FAIL_COUNT:-0}" -eq 0 ]; then
        rm -rf "${OUTPUT_DIR}"
    fi
}
trap cleanup EXIT

# Parse sp-sizes from positional args; default to 1 2 4 8
if [ $# -gt 0 ]; then
    SP_SIZES=("$@")
else
    SP_SIZES=(1 2 4 8)
fi

ZERO_STAGES=(0 1)

PASS_COUNT=0
FAIL_COUNT=0
TOTAL_COUNT=0
declare -a RESULTS=()

echo ""
echo "================================================================"
echo " AutoSP Correctness Test Suite"
echo "================================================================"
echo " SP sizes: ${SP_SIZES[*]}"
echo " Zero stages: ${ZERO_STAGES[*]}"
echo " Steps: ${STEPS}"
echo " Output dir: ${OUTPUT_DIR}"
echo "================================================================"
echo ""

for sp_size in "${SP_SIZES[@]}"; do
    for zero_stage in "${ZERO_STAGES[@]}"; do
        TEST_NAME="sp${sp_size}_zero${zero_stage}"
        TEST_DIR="${OUTPUT_DIR}/${TEST_NAME}"
        mkdir -p "${TEST_DIR}"

        ((TOTAL_COUNT++))

        echo "----------------------------------------------------------------"
        echo " Test: sp_size=${sp_size}, zero_stage=${zero_stage}"
        echo "----------------------------------------------------------------"

        # --- Baseline (compiled DS-Ulysses) ---
        echo " [1/3] Running baseline (--compile compile) ..."
        if ! python3 "${SCRIPT_DIR}/correctness_run.py" \
            --compile compile \
            --sp-size "${sp_size}" \
            --zero-stage "${zero_stage}" \
            --steps "${STEPS}" \
            --output-file "${TEST_DIR}/baseline.json"; then
            echo " FAIL: Baseline training failed"
            RESULTS+=(" ${TEST_NAME}: FAIL (baseline training error)")
            ((FAIL_COUNT++))
            echo ""
            continue
        fi

        # --- AutoSP ---
        echo " [2/3] Running autosp (--compile autosp) ..."
        if ! python3 "${SCRIPT_DIR}/correctness_run.py" \
            --compile autosp \
            --sp-size "${sp_size}" \
            --zero-stage "${zero_stage}" \
            --steps "${STEPS}" \
            --output-file "${TEST_DIR}/autosp.json"; then
            echo " FAIL: AutoSP training failed"
            RESULTS+=(" ${TEST_NAME}: FAIL (autosp training error)")
            ((FAIL_COUNT++))
            echo ""
            continue
        fi

        # --- Validate ---
        echo " [3/3] Validating per-rank losses ..."
        if python3 "${SCRIPT_DIR}/validator.py" \
            --baseline "${TEST_DIR}/baseline.json" \
            --autosp "${TEST_DIR}/autosp.json"; then
            RESULTS+=(" ${TEST_NAME}: PASS")
            ((PASS_COUNT++))
        else
            RESULTS+=(" ${TEST_NAME}: FAIL")
            ((FAIL_COUNT++))
        fi

        echo ""
    done
done

# ---- Summary ----
echo "================================================================"
echo " SUMMARY"
echo "================================================================"
for result in "${RESULTS[@]}"; do
    echo "${result}"
done
echo ""
echo " Passed: ${PASS_COUNT}/${TOTAL_COUNT} Failed: ${FAIL_COUNT}/${TOTAL_COUNT}"
echo "================================================================"

if [ "${FAIL_COUNT}" -gt 0 ]; then
    exit 1
fi
exit 0
Lines changed: 260 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,260 @@
1+
"""
2+
Runs training for a specific configuration (compile mode, sp_size, zero_stage)
3+
and saves per-rank losses to a JSON file.
4+
5+
Reuses the existing run.py training script with temporary config files,
6+
launching via accelerate in the same way as run_autosp.sh.
7+
"""
8+
9+
import argparse
10+
import csv
11+
import json
12+
import os
13+
import re
14+
import socket
15+
import subprocess
16+
import sys
17+
import tempfile
18+
19+
20+
def get_free_port():
    """Ask the OS for an unused TCP port and return its number."""
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        sock.bind(("", 0))
        return sock.getsockname()[1]
    finally:
        sock.close()
24+
25+
26+
def get_host_ip():
    """Return this host's IP as reported by `hostname -i`, or loopback on any failure."""
    try:
        completed = subprocess.run(
            ["hostname", "-i"], capture_output=True, text=True, check=True
        )
        # `hostname -i` may print several addresses; keep the first one.
        tokens = completed.stdout.strip().split()
        return tokens[0]
    except Exception:
        # Command missing, non-zero exit, or empty output -> fall back.
        return "127.0.0.1"
34+
35+
36+
def create_ds_config(compile_mode, sp_size, zero_stage, config_path):
    """Write a DeepSpeed JSON config for the given configuration to config_path.

    The "compile" section (deepcompile + autosp pass) is emitted only when
    compile_mode is "autosp"; the baseline run uses the plain config.
    """
    cfg = {
        "bf16": {"enabled": True},
        "zero_optimization": {"stage": zero_stage},
        "gradient_accumulation_steps": 1,
        "gradient_clipping": "auto",
        "steps_per_print": 2000,
        "train_batch_size": "auto",
        "train_micro_batch_size_per_gpu": "auto",
        "wall_clock_breakdown": False,
    }
    wants_autosp = compile_mode == "autosp"
    if wants_autosp:
        cfg["compile"] = {
            "deepcompile": True,
            "passes": ["autosp"],
            "pass_args": {"sp_size": sp_size},
        }
    with open(config_path, "w") as handle:
        handle.write(json.dumps(cfg, indent=4))
56+
57+
58+
def create_accelerate_config(ds_config_path, sp_size, config_path):
    """Write a single-node accelerate YAML config that defers to the DS JSON config."""
    lines = [
        "compute_environment: LOCAL_MACHINE",
        "debug: false",
        "deepspeed_config:",
        "  deepspeed_multinode_launcher: standard",
        f"  deepspeed_config_file: {ds_config_path}",
        "distributed_type: DEEPSPEED",
        "machine_rank: 0",
        "main_training_function: main",
        "num_machines: 1",
        # One process per sequence-parallel rank.
        f"num_processes: {sp_size}",
        "rdzv_backend: static",
        "same_network: true",
        "tpu_env: []",
        "tpu_use_cluster: false",
        "tpu_use_sudo: false",
        "use_cpu: false",
    ]
    with open(config_path, "w") as handle:
        handle.write("\n".join(lines) + "\n")
80+
81+
82+
def parse_losses_from_csv(logs_dir, compile_mode, seq_length, sp_size):
    """Read the per-rank loss CSVs written by run.py (full precision).

    Returns a mapping of str(rank) -> {str(step): loss}.  Ranks whose CSV
    file is missing are simply skipped.
    """
    per_rank = {}
    for rank in range(sp_size):
        path = os.path.join(
            logs_dir, f"loss_{compile_mode}_seq{seq_length}_rank{rank}.csv"
        )
        if not os.path.exists(path):
            continue
        with open(path, "r") as handle:
            rows = list(csv.DictReader(handle))
        per_rank[str(rank)] = {str(row["step"]): float(row["loss"]) for row in rows}
    return per_rank
98+
99+
100+
def parse_losses_from_stdout(output):
    """Fallback parser: extract '[Rank r] ... Step s, Loss: x' lines from stdout.

    Returns the same rank -> {step: loss} shape as parse_losses_from_csv,
    but at whatever precision the training loop printed.
    """
    pattern = re.compile(r"\[Rank (\d+)\].*Step (\d+), Loss: ([\d.]+)")
    parsed = {}
    for line in output.split("\n"):
        match = pattern.search(line)
        if match is None:
            continue
        rank, step, loss_text = match.groups()
        parsed.setdefault(rank, {})[step] = float(loss_text)
    return parsed
110+
111+
112+
def cleanup_csv_files(logs_dir, compile_mode, seq_length, sp_size):
    """Best-effort removal of the per-rank loss CSV files created by run.py."""
    targets = [
        os.path.join(logs_dir, f"loss_{compile_mode}_seq{seq_length}_rank{rank}.csv")
        for rank in range(sp_size)
    ]
    for target in targets:
        try:
            os.remove(target)
        except FileNotFoundError:
            # A missing file just means that rank never wrote one.
            pass
122+
123+
124+
def main() -> None:
    """CLI entry point: run one training configuration and save per-rank losses.

    Builds temporary DeepSpeed / accelerate config files for the requested
    (compile mode, sp_size, zero_stage), launches run.py via
    `accelerate launch`, then writes the collected per-rank losses to
    --output-file as JSON.  The full training output is always saved next to
    it as a .log file for debugging.
    """
    parser = argparse.ArgumentParser(
        description="Run training and capture per-rank losses"
    )
    parser.add_argument("--compile", choices=["compile", "autosp"], required=True)
    parser.add_argument("--sp-size", type=int, required=True)
    parser.add_argument("--zero-stage", type=int, choices=[0, 1], required=True)
    parser.add_argument("--steps", type=int, default=5)
    parser.add_argument("--output-file", type=str, required=True)
    parser.add_argument("--seq-length", type=int, default=64)
    parser.add_argument("--batch-size", type=int, default=1)
    parser.add_argument("--num-layers", type=int, default=1)
    parser.add_argument("--verbose", action="store_true")
    args = parser.parse_args()

    # run.py and its logs/ directory live one level above this script.
    script_dir = os.path.dirname(os.path.abspath(__file__))
    autosp_dir = os.path.abspath(os.path.join(script_dir, ".."))
    run_py = os.path.join(autosp_dir, "run.py")
    logs_dir = os.path.join(autosp_dir, "logs")

    output_dir = os.path.dirname(os.path.abspath(args.output_file))
    os.makedirs(output_dir, exist_ok=True)

    # Config files are written to a temp dir that is removed automatically.
    with tempfile.TemporaryDirectory() as tmpdir:
        ds_config_path = os.path.join(tmpdir, "ds_config.json")
        accel_config_path = os.path.join(tmpdir, "accelerate_config.yaml")

        create_ds_config(args.compile, args.sp_size, args.zero_stage, ds_config_path)
        create_accelerate_config(ds_config_path, args.sp_size, accel_config_path)

        host_ip = get_host_ip()
        port = get_free_port()

        # Mirror the launch used by run_autosp.sh; --deterministic keeps the
        # baseline and autosp runs comparable step-for-step.
        cmd = [
            "accelerate", "launch",
            "--main_process_ip", host_ip,
            "--main_process_port", str(port),
            "--num_machines", "1",
            "--num_processes", str(args.sp_size),
            "--machine_rank", "0",
            "--config_file", accel_config_path,
            run_py,
            "--model_name", "meta-llama/Llama-2-7b-chat-hf",
            "--batch_size", str(args.batch_size),
            "--seq_length", str(args.seq_length),
            "--sp_size", str(args.sp_size),
            "--dp_size", "1",
            "--backend", "inductor",
            "--compile", args.compile,
            "--num_layers", str(args.num_layers),
            "--steps", str(args.steps),
            "--deterministic",
        ]

        env = os.environ.copy()
        env["NCCL_DEBUG"] = "WARN"

        output = ""
        stderr_output = ""

        if args.verbose:
            # Stream training output live; stderr is merged into stdout.
            process = subprocess.Popen(
                cmd,
                stdout=subprocess.PIPE,
                stderr=subprocess.STDOUT,
                text=True,
                cwd=autosp_dir,
                env=env,
            )
            for line in process.stdout:
                output += line
                sys.stdout.write(line)
                sys.stdout.flush()
            process.wait()
            return_code = process.returncode
        else:
            result = subprocess.run(
                cmd, capture_output=True, text=True, cwd=autosp_dir, env=env
            )
            output = result.stdout
            stderr_output = result.stderr
            return_code = result.returncode

        # Save training log for debugging (written even when training fails).
        # NOTE(review): assumes --output-file ends in ".json"; otherwise
        # str.replace could alter an unexpected part of the path — confirm.
        log_path = args.output_file.replace(".json", ".log")
        with open(log_path, "w") as f:
            f.write(f"Command: {' '.join(cmd)}\n")
            f.write(f"Return code: {return_code}\n")
            f.write("=" * 60 + "\n")
            f.write(output)
            if stderr_output:
                f.write("\n--- STDERR ---\n")
                f.write(stderr_output)

        if return_code != 0:
            print(f" Training failed (exit code {return_code}). See: {log_path}")
            if not args.verbose:
                # Show the tail of the captured output for quick diagnosis.
                lines = (output + stderr_output).strip().split("\n")
                for line in lines[-30:]:
                    print(f"  {line}")
            # Remove any partial CSVs so they cannot contaminate a later run.
            cleanup_csv_files(logs_dir, args.compile, args.seq_length, args.sp_size)
            sys.exit(1)

        # Read full-precision losses from CSV files written by run.py
        losses = parse_losses_from_csv(
            logs_dir, args.compile, args.seq_length, args.sp_size
        )
        cleanup_csv_files(logs_dir, args.compile, args.seq_length, args.sp_size)

        if not losses:
            # Lower-precision fallback: scrape losses from the printed output.
            print(" Warning: CSV loss files not found, falling back to stdout parsing")
            losses = parse_losses_from_stdout(output)

        if not losses:
            print(" Error: No losses found in training output")
            sys.exit(1)

        result_data = {
            "config": {
                "compile": args.compile,
                "sp_size": args.sp_size,
                "zero_stage": args.zero_stage,
                "steps": args.steps,
            },
            "losses": losses,
        }

        with open(args.output_file, "w") as f:
            json.dump(result_data, f, indent=2)

        num_ranks = len(losses)
        # max() is safe here: the empty-losses case exited above.
        num_steps = max(len(v) for v in losses.values())
        print(f" Losses saved: {num_ranks} rank(s), {num_steps} step(s) -> {args.output_file}")
257+
258+
259+
# Script entry point: parse CLI args, run training, dump losses to JSON.
if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)