diff --git a/qa/L0_pytorch_unittest/test.sh b/qa/L0_pytorch_unittest/test.sh index a13dfada79..cd2d85c91c 100644 --- a/qa/L0_pytorch_unittest/test.sh +++ b/qa/L0_pytorch_unittest/test.sh @@ -48,7 +48,11 @@ python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_attention.xml $T NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_attention_deterministic.xml $TE_PATH/tests/pytorch/attention/test_attention.py || test_fail "NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 test_attention.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_kv_cache.xml $TE_PATH/tests/pytorch/attention/test_kv_cache.py || test_fail "test_kv_cache.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_hf_integration.xml $TE_PATH/tests/pytorch/test_hf_integration.py || test_fail "test_hf_integration.py" -NVTE_TEST_CHECKPOINT_ARTIFACT_PATH=$TE_PATH/artifacts/tests/pytorch/test_checkpoint python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_checkpoint.xml $TE_PATH/tests/pytorch/test_checkpoint.py || test_fail "test_checkpoint.py" +export NVTE_TEST_CHECKPOINT_ARTIFACT_PATH=$TE_PATH/artifacts/tests/pytorch/test_checkpoint +if [ ! -d "$NVTE_TEST_CHECKPOINT_ARTIFACT_PATH" ]; then + python3 $TE_PATH/tests/pytorch/test_checkpoint.py --save-checkpoint all || error_exit "Failed to generate checkpoint files" +fi +python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_checkpoint.xml $TE_PATH/tests/pytorch/test_checkpoint.py || test_fail "test_checkpoint.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fused_router.xml $TE_PATH/tests/pytorch/test_fused_router.py || test_fail "test_fused_router.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_partial_cast.xml $TE_PATH/tests/pytorch/test_partial_cast.py || test_fail "test_partial_cast.py" diff --git a/tests/pytorch/test_checkpoint.py b/tests/pytorch/test_checkpoint.py index 1383264fdc..0427886b84 100644 --- a/tests/pytorch/test_checkpoint.py +++ b/tests/pytorch/test_checkpoint.py @@ -101,7 +101,7 @@ def _save_checkpoint(name: str, checkpoint_dir: Optional[pathlib.Path] = None) - # Path to save checkpoint if checkpoint_dir is None: checkpoint_dir = TestLoadCheckpoint._checkpoint_dir() - checkpoint_dir.mkdir(exist_ok=True) + checkpoint_dir.mkdir(parents=True, exist_ok=True) checkpoint_file = checkpoint_dir / f"{name}.pt" # Create module and save checkpoint