Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -323,4 +323,4 @@ def load_weights(self, rng_key: jax.Array) -> None:
model = model_creation_utils.from_pretrained(
self.maxtext_config, mesh=self.mesh, model_mode=self.model_mode, rng_key=rng_key
)
self.model = nnx.data(model)
self.model = nnx.data(model)
42 changes: 42 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,48 @@
import jax
import importlib.util

# --- Monkeypatch for absl.testing.parameterized ---
# Context: Decorating a test method with @parameterized.named_parameters returns a custom
# iterable container (_ParameterizedTestIter) instead of a standard function object.
# Problem: When pytest markers are applied above @parameterized in the decorator stack:
#
# @pytest.mark.cpu_only
# @parameterized.named_parameters(...)
# def test_foo(self, ...):
#
# pytest attaches the marker attributes exclusively to the outer iterable container object.
# During class initialization, the test metaclass unwraps the base function to generate
# individual test methods, omitting the outer container entirely. Consequently, marker
# attributes attached to the outer container are dropped and lost before pytest collection.
# Solution: Intercept _ParameterizedTestIter.__iter__ to dynamically propagate any discovered
# pytestmark attributes from the outer container object down to all generated test methods.
from absl.testing import parameterized

try:
# pylint: disable=protected-access
_orig_iter = parameterized._ParameterizedTestIter.__iter__

def _custom_iter(self):
"""Custom iterator propagating outer pytestmark attributes to generated test methods."""
outer_marks = getattr(self, "pytestmark", None)
if outer_marks is None:
yield from _orig_iter(self)
else:
if not isinstance(outer_marks, list):
outer_marks = [outer_marks]

for func in _orig_iter(self):
existing_marks = getattr(func, "pytestmark", [])
if not isinstance(existing_marks, list):
existing_marks = [existing_marks]
func.pytestmark = existing_marks + outer_marks
yield func

parameterized._ParameterizedTestIter.__iter__ = _custom_iter
# pylint: enable=protected-access
except AttributeError:
pass

try:
_HAS_TPU = any(d.platform == "tpu" for d in jax.devices())
except Exception: # pragma: no cover pylint: disable=broad-exception-caught
Expand Down
60 changes: 60 additions & 0 deletions tests/unit/marker_propagation_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Unit tests validating pytest marker propagation through decorator stacks."""

import functools
import unittest

from absl.testing import parameterized
import jax
import pytest


def dummy_decorator(func):
"""Standard transparent wrapper decorator preserving function metadata."""

@functools.wraps(func)
def wrapper(*args, **kwargs):
return func(*args, **kwargs)

return wrapper


Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟢 Excellent addition of a dedicated test to verify marker propagation. This ensures the monkeypatch works as expected and provides a regression test for any future changes to the testing infrastructure.

class MarkerPropagationTest(parameterized.TestCase):
"""Validates that pytest markers propagate correctly through decorator stacks."""

@pytest.mark.cpu_only
@parameterized.named_parameters(
{"testcase_name": "default", "unused": None},
)
def test_parameterized_cpu_only_marker_propagation(self, unused):
"""Verifies cpu_only marker above @parameterized propagates to generated methods."""
has_tpu = any(d.platform == "tpu" for d in jax.devices())
has_gpu = any(d.platform == "gpu" for d in jax.devices())
assert not has_tpu, "cpu_only parameterized test accidentally executed on TPU hardware"
assert not has_gpu, "cpu_only parameterized test accidentally executed on GPU hardware"

@pytest.mark.cpu_only
@dummy_decorator
def test_standard_decorator_cpu_only_marker_propagation(self):
"""Verifies cpu_only marker above standard decorators propagates correctly."""
has_tpu = any(d.platform == "tpu" for d in jax.devices())
has_gpu = any(d.platform == "gpu" for d in jax.devices())
assert not has_tpu, "cpu_only standard decorated test accidentally executed on TPU hardware"
assert not has_gpu, "cpu_only standard decorated test accidentally executed on GPU hardware"


if __name__ == "__main__":
unittest.main()
Loading