Skip to content
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
fc843ed
Add support to download models automatically if --model specified in …
rohan-uiuc Oct 30, 2025
5f790ff
create model dir if it doesn't exist
rohan-uiuc Oct 30, 2025
0f22bec
Check model weights existence before binding; use HF model name if mi…
rohan-uiuc Nov 12, 2025
9f2fdd2
Remove commented code
rohan-uiuc Nov 12, 2025
38011be
Apply code formatting fixes from pre-commit
rohan-uiuc Nov 12, 2025
4de3563
revert unnecessary test change
rohan-uiuc Nov 12, 2025
eb1e929
Merge branch 'develop' of https://github.com/Center-for-AI-Innovation…
rohan-uiuc Nov 12, 2025
8b6a211
Apply formatting fixes from pre-commit
rohan-uiuc Nov 12, 2025
c68cb35
Add tests for model weights existence coverage
rohan-uiuc Nov 12, 2025
b610891
Remove redundant /dev/infiniband
rohan-uiuc Jan 5, 2026
a7a5deb
Remove unused variable
rohan-uiuc Jan 5, 2026
bb3142b
Add warning if downloading weights and HF cache not set
rohan-uiuc Jan 5, 2026
2db9c6c
format ONLY
rohan-uiuc Jan 5, 2026
079c86a
Add --hf-model CLI option and config field
rohan-uiuc Jan 29, 2026
20163da
Use hf_model as model source when local weights missing
rohan-uiuc Jan 29, 2026
ed82b77
Pass hf_model from CLI to launch params
rohan-uiuc Jan 29, 2026
d3f6772
Add tests for hf_model override in slurm script generation
rohan-uiuc Jan 29, 2026
1c312df
Add documentation for --hf-model option
rohan-uiuc Jan 29, 2026
cd7a9b1
Rebase with main and update hf_model field processing logic
XkunW Mar 29, 2026
03864a5
Fix typos
XkunW Mar 30, 2026
292200f
Fix tests
XkunW Mar 30, 2026
ba9030f
ruff check & format
XkunW Mar 30, 2026
b4037e6
Merge branch 'main' into hf_download
XkunW Mar 30, 2026
8e66e26
ruff format
XkunW Mar 30, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 57 additions & 0 deletions tests/vec_inf/client/test_slurm_script_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,14 @@
)


@pytest.fixture(autouse=True)
def patch_model_weights_exists(monkeypatch):
    """Make every Path.exists() check in the generator module report True.

    Applied automatically so the default test environment behaves as if the
    model weights directory is present; individual tests re-patch to False
    when exercising the missing-weights code paths.
    """

    def _always_present(self):
        return True

    monkeypatch.setattr(
        "vec_inf.client._slurm_script_generator.Path.exists", _always_present
    )


class TestSlurmScriptGenerator:
"""Tests for SlurmScriptGenerator class."""

Expand Down Expand Up @@ -168,6 +176,21 @@ def test_generate_server_setup_singularity(self, singularity_params):
"module load " in setup
) # Remove module name since it's inconsistent between clusters

def test_generate_server_setup_singularity_no_weights(
self, singularity_params, monkeypatch
):
"""Test server setup when model weights don't exist."""
monkeypatch.setattr(
"vec_inf.client._slurm_script_generator.Path.exists",
lambda self: False,
)

generator = SlurmScriptGenerator(singularity_params)
setup = generator._generate_server_setup()

assert "ray stop" in setup
assert "/path/to/model_weights/test-model" not in setup

def test_generate_launch_cmd_venv(self, basic_params):
"""Test launch command generation with virtual environment."""
generator = SlurmScriptGenerator(basic_params)
Expand All @@ -187,6 +210,22 @@ def test_generate_launch_cmd_singularity(self, singularity_params):
assert "apptainer exec --nv" in launch_cmd
assert "source" not in launch_cmd

def test_generate_launch_cmd_singularity_no_local_weights(
self, singularity_params, monkeypatch
):
"""Test container launch when model weights directory is missing."""
monkeypatch.setattr(
"vec_inf.client._slurm_script_generator.Path.exists",
lambda self: False,
)

generator = SlurmScriptGenerator(singularity_params)
launch_cmd = generator._generate_launch_cmd()

assert "exec --nv" in launch_cmd
assert "--bind /path/to/model_weights/test-model" not in launch_cmd
assert "vllm serve test-model" in launch_cmd

def test_generate_launch_cmd_boolean_args(self, basic_params):
"""Test launch command with boolean vLLM arguments."""
params = basic_params.copy()
Expand Down Expand Up @@ -391,6 +430,24 @@ def test_generate_model_launch_script_singularity(
mock_touch.assert_called_once()
mock_write_text.assert_called_once()

@patch("pathlib.Path.touch")
@patch("pathlib.Path.write_text")
def test_generate_model_launch_script_singularity_no_weights(
self, mock_write_text, mock_touch, batch_singularity_params, monkeypatch
):
"""Test batch model launch script when model weights don't exist."""
monkeypatch.setattr(
"vec_inf.client._slurm_script_generator.Path.exists",
lambda self: False,
)

generator = BatchSlurmScriptGenerator(batch_singularity_params)
script_path = generator._generate_model_launch_script("model1")

assert script_path.name == "launch_model1.sh"
call_args = mock_write_text.call_args[0][0]
assert "/path/to/model_weights/model1" not in call_args

@patch("vec_inf.client._slurm_script_generator.datetime")
@patch("pathlib.Path.touch")
@patch("pathlib.Path.write_text")
Expand Down
69 changes: 51 additions & 18 deletions vec_inf/client/_slurm_script_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,18 @@ def __init__(self, params: dict[str, Any]):
self.additional_binds = (
f",{self.params['bind']}" if self.params.get("bind") else ""
)
self.model_weights_path = str(
Path(self.params["model_weights_parent_dir"], self.params["model_name"])
model_weights_path = Path(
self.params["model_weights_parent_dir"], self.params["model_name"]
)
self.model_weights_exists = model_weights_path.exists()
self.model_weights_path = str(model_weights_path)
self.model_source = (
self.model_weights_path
if self.model_weights_exists
else self.params["model_name"]
)
self.model_bind_option = (
Comment thread
rohan-uiuc marked this conversation as resolved.
Outdated
f",{self.model_weights_path}" if self.model_weights_exists else ""
)
self.env_str = self._generate_env_str()

Expand Down Expand Up @@ -111,7 +121,9 @@ def _generate_server_setup(self) -> str:
server_script.append("\n".join(SLURM_SCRIPT_TEMPLATE["container_setup"]))
server_script.append(
SLURM_SCRIPT_TEMPLATE["bind_path"].format(
model_weights_path=self.model_weights_path,
model_weights_path=self.model_weights_path
if self.model_weights_exists
else "",
additional_binds=self.additional_binds,
)
)
Expand All @@ -131,7 +143,6 @@ def _generate_server_setup(self) -> str:
server_setup_str = server_setup_str.replace(
"CONTAINER_PLACEHOLDER",
SLURM_SCRIPT_TEMPLATE["container_command"].format(
model_weights_path=self.model_weights_path,
env_str=self.env_str,
),
)
Expand Down Expand Up @@ -165,22 +176,27 @@ def _generate_launch_cmd(self) -> str:
Server launch command.
"""
launcher_script = ["\n"]

vllm_args_copy = self.params["vllm_args"].copy()
Comment thread
XkunW marked this conversation as resolved.
Outdated
model_source = self.model_source
if "--model" in vllm_args_copy:
model_source = vllm_args_copy.pop("--model")

if self.use_container:
launcher_script.append(
SLURM_SCRIPT_TEMPLATE["container_command"].format(
model_weights_path=self.model_weights_path,
env_str=self.env_str,
)
)

launcher_script.append(
"\n".join(SLURM_SCRIPT_TEMPLATE["launch_cmd"]).format(
model_weights_path=self.model_weights_path,
model_source=model_source,
model_name=self.params["model_name"],
)
)

for arg, value in self.params["vllm_args"].items():
for arg, value in vllm_args_copy.items():
if isinstance(value, bool):
launcher_script.append(f" {arg} \\")
else:
Expand Down Expand Up @@ -225,11 +241,20 @@ def __init__(self, params: dict[str, Any]):
if self.params["models"][model_name].get("bind")
else ""
)
self.params["models"][model_name]["model_weights_path"] = str(
Path(
self.params["models"][model_name]["model_weights_parent_dir"],
model_name,
)
model_weights_path = Path(
self.params["models"][model_name]["model_weights_parent_dir"],
model_name,
)
model_weights_exists = model_weights_path.exists()
model_weights_path_str = str(model_weights_path)
self.params["models"][model_name]["model_weights_path"] = (
model_weights_path_str
)
self.params["models"][model_name]["model_weights_exists"] = (
model_weights_exists
)
self.params["models"][model_name]["model_source"] = (
model_weights_path_str if model_weights_exists else model_name
)

def _write_to_log_dir(self, script_content: list[str], script_name: str) -> Path:
Expand Down Expand Up @@ -266,7 +291,9 @@ def _generate_model_launch_script(self, model_name: str) -> Path:
script_content.append(BATCH_MODEL_LAUNCH_SCRIPT_TEMPLATE["container_setup"])
script_content.append(
BATCH_MODEL_LAUNCH_SCRIPT_TEMPLATE["bind_path"].format(
model_weights_path=model_params["model_weights_path"],
model_weights_path=model_params["model_weights_path"]
if model_params.get("model_weights_exists", True)
else "",
additional_binds=model_params["additional_binds"],
)
)
Expand All @@ -283,19 +310,25 @@ def _generate_model_launch_script(self, model_name: str) -> Path:
model_name=model_name,
)
)
vllm_args_copy = model_params["vllm_args"].copy()
model_source = model_params.get(
"model_source", model_params["model_weights_path"]
)
if "--model" in vllm_args_copy:
model_source = vllm_args_copy.pop("--model")

if self.use_container:
script_content.append(
BATCH_MODEL_LAUNCH_SCRIPT_TEMPLATE["container_command"].format(
model_weights_path=model_params["model_weights_path"],
)
BATCH_MODEL_LAUNCH_SCRIPT_TEMPLATE["container_command"].format()
)
script_content.append(
"\n".join(BATCH_MODEL_LAUNCH_SCRIPT_TEMPLATE["launch_cmd"]).format(
model_weights_path=model_params["model_weights_path"],
model_source=model_source,
model_name=model_name,
)
)
for arg, value in model_params["vllm_args"].items():

for arg, value in vllm_args_copy.items():
if isinstance(value, bool):
script_content.append(f" {arg} \\")
else:
Expand Down
6 changes: 3 additions & 3 deletions vec_inf/client/_slurm_templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ class SlurmScriptTemplate(TypedDict):
f"{CONTAINER_MODULE_NAME} exec {IMAGE_PATH} ray stop",
],
"imports": "source {src_dir}/find_port.sh",
"bind_path": f"export {CONTAINER_MODULE_NAME.upper()}_BINDPATH=${CONTAINER_MODULE_NAME.upper()}_BINDPATH,/dev,/tmp,{{model_weights_path}}{{additional_binds}}",
"bind_path": f"export {CONTAINER_MODULE_NAME.upper()}_BINDPATH=${CONTAINER_MODULE_NAME.upper()}_BINDPATH,$(echo /dev/infiniband* | sed -e 's/ /,/g'),/dev,/tmp{{model_weights_path}}{{additional_binds}}",
Comment thread
rohan-uiuc marked this conversation as resolved.
Outdated
"container_command": f"{CONTAINER_MODULE_NAME} exec --nv {{env_str}} --containall {IMAGE_PATH} \\",
"activate_venv": "source {venv}/bin/activate",
"server_setup": {
Expand Down Expand Up @@ -164,7 +164,7 @@ class SlurmScriptTemplate(TypedDict):
' && mv temp.json "$json_path"',
],
"launch_cmd": [
"vllm serve {model_weights_path} \\",
"vllm serve {model_source} \\",
" --served-model-name {model_name} \\",
' --host "0.0.0.0" \\',
" --port $vllm_port_number \\",
Expand Down Expand Up @@ -255,7 +255,7 @@ class BatchModelLaunchScriptTemplate(TypedDict):
],
"container_command": f"{CONTAINER_MODULE_NAME} exec --nv --containall {IMAGE_PATH} \\",
"launch_cmd": [
"vllm serve {model_weights_path} \\",
"vllm serve {model_source} \\",
" --served-model-name {model_name} \\",
' --host "0.0.0.0" \\',
" --port $vllm_port_number \\",
Expand Down