From a9b153ba14695fc06fb1e6d05e02a6f6199c4ba9 Mon Sep 17 00:00:00 2001 From: Ayush Agrawal Date: Fri, 1 May 2026 14:54:32 -0700 Subject: [PATCH] chore: chore PiperOrigin-RevId: 908911389 --- noxfile.py | 120 +++++++++++++++++++++++++++++++++++++++-------------- setup.py | 1 + 2 files changed, 89 insertions(+), 32 deletions(-) diff --git a/noxfile.py b/noxfile.py index ce2d2c3865..36177a06b1 100644 --- a/noxfile.py +++ b/noxfile.py @@ -69,6 +69,7 @@ "pytest-asyncio", # Preventing: py.test: error: unrecognized arguments: -n=auto --dist=loadscope "pytest-xdist", + "pytest-shard", ] UNIT_TEST_EXTERNAL_DEPENDENCIES = [] UNIT_TEST_LOCAL_DEPENDENCIES = [] @@ -196,41 +197,96 @@ def install_unittest_dependencies(session, *constraints): def default(session): - # Install all test dependencies, then install this package in-place. + # Install all test dependencies, then install this package in-place. - constraints_path = str( + constraints_path = str( CURRENT_DIRECTORY / "testing" / f"constraints-{session.python}.txt" ) - install_unittest_dependencies(session, "-c", constraints_path) - - # Run py.test against the unit tests. - session.run( - "py.test", - "--quiet", - f"--junitxml=unit_{session.python}_sponge_log.xml", - "--cov=google", - "--cov-append", - "--cov-config=.coveragerc", - "--cov-report=", - "--cov-fail-under=0", - "--ignore=tests/unit/vertex_ray", - "--ignore=tests/unit/vertex_adk", - "--ignore=tests/unit/vertex_langchain", - "--ignore=tests/unit/vertex_ag2", - "--ignore=tests/unit/vertex_llama_index", - "--ignore=tests/unit/architecture", - os.path.join("tests", "unit"), - *session.posargs, - ) - - # Run tests that require isolation. - session.run( - "py.test", - "--quiet", - f"--junitxml=unit_{session.python}_test_vertexai_import_sponge_log.xml", - os.path.join("tests", "unit", "architecture", "test_vertexai_import.py"), - *session.posargs, - ) + install_unittest_dependencies(session, "-c", constraints_path) + + ml_framework_tests = [ + os.path.join("tests", "unit", "aiplatform", "test_uploader.py"), + os.path.join("tests", "unit", "aiplatform", "test_metadata_models.py"), + os.path.join("tests", "unit", "aiplatform", "test_metadata.py"), + os.path.join("tests", "unit", "aiplatform", "test_logdir_loader.py"), + os.path.join("tests", "unit", "aiplatform", "test_uploader_utils.py"), + os.path.join("tests", "unit", "aiplatform", "test_explain_lit.py"), + os.path.join( + "tests", + "unit", + "aiplatform", + "test_explain_saved_model_metadata_builder_tf1_test.py", + ), + os.path.join( + "tests", + "unit", + "aiplatform", + "test_explain_saved_model_metadata_builder_tf2_test.py", + ), + ] + + shard_id = os.environ.get("PYTEST_SHARD_ID") + shard_count = os.environ.get("PYTEST_SHARD_COUNT") + shard_args = [] + if shard_id and shard_count: + shard_args = [f"--shard-id={shard_id}", f"--shard-count={shard_count}"] + + core_pytest_args = [ + "py.test", + "--quiet", + f"--junitxml=unit_core_{session.python}_sponge_log.xml", + "--cov=google", + "--cov-append", + "--cov-config=.coveragerc", + "--cov-report=", + "--cov-fail-under=0", + "--ignore=tests/unit/vertex_ray", + "--ignore=tests/unit/vertex_adk", + "--ignore=tests/unit/vertex_langchain", + "--ignore=tests/unit/vertex_ag2", + "--ignore=tests/unit/vertex_llama_index", + "--ignore=tests/unit/architecture", + ] + + for ml_test in ml_framework_tests: + core_pytest_args.append(f"--ignore={ml_test}") + + core_pytest_args.extend(shard_args) + core_pytest_args.append(os.path.join("tests", "unit")) + core_pytest_args.extend(session.posargs) + + # Run py.test against the core unit tests. + session.run(*core_pytest_args) + + ml_pytest_args = [ + "py.test", + "--quiet", + f"--junitxml=unit_ml_{session.python}_sponge_log.xml", + "--cov=google", + "--cov-append", + "--cov-config=.coveragerc", + "--cov-report=", + "--cov-fail-under=0", + ] + + ml_pytest_args.extend(shard_args) + ml_pytest_args.extend(ml_framework_tests) + ml_pytest_args.extend(session.posargs) + + ml_env = os.environ.copy() + ml_env["PYTEST_ADDOPTS"] = "-n=4 --dist=loadscope" + + # Run py.test against the ML unit tests. + session.run(*ml_pytest_args, env=ml_env) + + # Run tests that require isolation. + session.run( + "py.test", + "--quiet", + f"--junitxml=unit_{session.python}_test_vertexai_import_sponge_log.xml", + os.path.join("tests", "unit", "architecture", "test_vertexai_import.py"), + *session.posargs, + ) @nox.session(python=UNIT_TEST_PYTHON_VERSIONS) diff --git a/setup.py b/setup.py index 6860a4c993..f892e235f0 100644 --- a/setup.py +++ b/setup.py @@ -273,6 +273,7 @@ "pytest-cov", "mock", "pytest-xdist", + "pytest-shard", "Pillow", "scikit-learn<1.6.0; python_version<='3.10'", "scikit-learn; python_version>'3.10'",