Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions src/maxtext/configs/post_train/rl.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,18 @@ rollout_expert_parallelism: 1
# ====== Reproducibility ======
data_shuffle_seed: 42

# ====== LoRA ======
# Low-Rank Adaptation for the actor model. When enabled, only the LoRA parameters
# are trained and checkpointed, significantly reducing memory and compute.
lora:
enabled: False
rank: 32
alpha: 64.0
# Regex matching module paths to apply LoRA to.
# Qwix uses re.fullmatch over slash-separated module paths like:
# layers/self_attention/query, layers/self_attention/key, ...
module_path: 'layers/(self_attention/(query|key|value|out)|mlp/(wi_0|wi_1|wo))'

# ====== RL ======
# This config includes RL algorithm variations such as grpo or gspo-token
rl:
Expand Down
13 changes: 13 additions & 0 deletions src/maxtext/configs/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -1622,6 +1622,15 @@ class VLLM(BaseModel):
)
vllm_hf_config_path: str = Field("", description="Path to HuggingFace model config for MaxText model.")


class LoraConfig(BaseModel):
  """Configuration for LoRA (Low-Rank Adaptation) applied to the actor model during RL training."""

  # Master switch: when False the actor is trained with full weights and this config is ignored.
  enabled: bool = Field(False, description="If True, apply LoRA to the actor model instead of full-weight training.")
  # NOTE(review): defaults here (rank=4, alpha=8.0, module_path=".*") differ from the values
  # shipped in configs/post_train/rl.yml (rank=32, alpha=64.0, attention/MLP-only module_path) —
  # confirm which set is meant to be the effective default.
  rank: int = Field(4, description="Rank of the LoRA decomposition.")
  alpha: float = Field(8.0, description="Alpha scaling parameter for LoRA (effective LR scale = alpha / rank).")
  # Matched by qwix via re.fullmatch against slash-separated module paths
  # (e.g. "layers/self_attention/query"); ".*" applies LoRA everywhere.
  module_path: str = Field(".*", description="Regex matching which module paths to apply LoRA to.")


class RL(BaseModel):
"""Configuration for RL algorithms like Group Relative Policy Optimization (GRPO) among others."""
Expand Down Expand Up @@ -1948,6 +1957,10 @@ class MaxTextConfig(
default_factory=RL,
description="Configuration for RL algorithms like Group Relative Policy Optimization (GRPO).",
)
lora: LoraConfig = Field(
default_factory=LoraConfig,
description="Configuration for LoRA applied to the actor model during RL training.",
)
model_config = ConfigDict(extra="forbid", protected_namespaces=())

@model_validator(mode="before")
Expand Down
13 changes: 13 additions & 0 deletions src/maxtext/integration/tunix/tunix_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,3 +93,16 @@ def lora_to_hf_mappings(self):
return {}

return self._vllm_weight_mapping.lora_to_hf_mappings()

def get_model_input(self):
  """Build a placeholder kwargs dict mirroring the model's __call__ signature.

  qwix.apply_lora_to_model needs one forward-pass worth of inputs to trace the
  module graph; only shapes and dtypes matter here, not the actual values.

  Returns:
    dict with dummy ``input_tokens`` and ``positions`` arrays of shape
    (1, 128) and dtype int32, plus ``cache`` and ``attention_mask`` left None.
  """
  import jax.numpy as jnp  # pylint: disable=import-outside-toplevel

  batch, seq = 1, 128
  dummy = lambda: jnp.ones((batch, seq), dtype=jnp.int32)  # pylint: disable=unnecessary-lambda-assignment
  return {
      "input_tokens": dummy(),
      "positions": dummy(),
      "cache": None,
      "attention_mask": None,
  }
9 changes: 6 additions & 3 deletions src/maxtext/layers/nnx_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,9 +253,12 @@ def __call__(
out, updates = self.to_nnx__module.init_with_output(_rngs, *args, method=method, **kwargs)
else:
nnx_attrs = {
k: v
for k, v in vars(self).items()
if not k.startswith("to_nnx__") and not k.startswith("_pytree__") and not k.startswith("_object__")
k: v
for k, v in vars(self).items()
if not k.startswith("to_nnx__")
and not k.startswith("_pytree__")
and not k.startswith("_object__")
and isinstance(v, dict) # linen variable collections are always dicts; skip plain attrs (e.g. qwix metadata)
}
variables = nnx_attrs_to_linen_vars(nnx_attrs)

Expand Down
18 changes: 17 additions & 1 deletion src/maxtext/trainers/post_train/rl/train_rl.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
from orbax import checkpoint as ocp
from pprint import pprint
from transformers import AutoTokenizer
import qwix
from tunix.rl import rl_cluster as rl_cluster_lib
from tunix.rl.rollout import base_rollout
from tunix.rl.grpo.grpo_learner import GrpoConfig, GrpoLearner
Expand All @@ -79,7 +80,6 @@
from maxtext.input_pipeline.instruction_data_processing import load_template_from_file
from maxtext.utils import max_logging, max_utils, maxtext_utils, model_creation_utils


def get_maxtext_model(config, devices=None):
"""
Load MaxText model with Tunix adapter.
Expand Down Expand Up @@ -419,6 +419,22 @@ def create_models_and_meshes(trainer_config, sampler_config, trainer_devices, sa
max_logging.log("Creating policy model with same config as reference model on trainer mesh")
actor_model, actor_mesh = get_maxtext_model(trainer_config, trainer_devices)

# add lora adapter to actor model
if trainer_config.lora.enabled:
max_logging.log(
f"Applying LoRA to actor model: rank={trainer_config.lora.rank}, "
f"alpha={trainer_config.lora.alpha}, module_path='{trainer_config.lora.module_path}'"
)
lora_provider = qwix.LoraProvider(
module_path=trainer_config.lora.module_path,
rank=trainer_config.lora.rank,
alpha=trainer_config.lora.alpha,
)
model_input = actor_model.get_model_input()
with actor_mesh:
actor_model = qwix.apply_lora_to_model(actor_model, lora_provider, rngs=nnx.Rngs(0), **model_input)
max_logging.log("LoRA applied to actor model successfully.")

return reference_model, reference_mesh, actor_model, actor_mesh, rollout_mesh


Expand Down
73 changes: 73 additions & 0 deletions tests/post_training/unit/train_rl_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,6 +364,79 @@ def get_dataset_side_effect(model_tokenizer, config, data_dir, split, data_files
self.assertEqual(len(test_batch["prompts"]), 1)
self.assertEqual(test_batch["prompts"][0], "short")

@pytest.mark.cpu_only
def test_create_models_and_meshes_applies_lora_when_enabled(self):
  """When lora.enabled=True, create_models_and_meshes calls qwix.apply_lora_to_model on the actor."""
  base_actor = mock.MagicMock()
  base_actor.get_model_input.return_value = {"input_tokens": "dummy"}
  reference = mock.MagicMock()
  mesh = mock.MagicMock()
  wrapped_actor = mock.MagicMock(name="lora_actor")

  trainer_config = SimpleNamespace(
      load_checkpoint_only_once=False,
      vllm_additional_config="",
      lora=SimpleNamespace(enabled=True, rank=8, alpha=16.0, module_path="layers/.*"),
  )
  sampler_config = SimpleNamespace(mesh_axes=("data",))

  # Patch out model loading, mesh construction, and qwix so only the LoRA wiring is exercised.
  model_loader = mock.patch(
      "maxtext.trainers.post_train.rl.train_rl.get_maxtext_model",
      side_effect=[(reference, mesh), (base_actor, mesh)],
  )
  device_mesh = mock.patch("maxtext.trainers.post_train.rl.train_rl.maxtext_utils.create_device_mesh", return_value=[])
  mesh_cls = mock.patch("maxtext.trainers.post_train.rl.train_rl.Mesh", return_value=mesh)
  provider_patch = mock.patch("maxtext.trainers.post_train.rl.train_rl.qwix.LoraProvider")
  apply_patch = mock.patch("maxtext.trainers.post_train.rl.train_rl.qwix.apply_lora_to_model", return_value=wrapped_actor)

  with model_loader, device_mesh, mesh_cls, provider_patch as provider_cls, apply_patch as apply_fn:
    _, _, actor, _, _ = train_rl.create_models_and_meshes(trainer_config, sampler_config, ["dev"], ["dev"])

  # The provider must be built from the config values, and the actor wrapped by qwix.
  provider_cls.assert_called_once_with(module_path="layers/.*", rank=8, alpha=16.0)
  apply_fn.assert_called_once()
  positional = apply_fn.call_args[0]
  self.assertIs(positional[0], base_actor)  # first positional arg is the model
  self.assertIs(positional[1], provider_cls.return_value)  # second is the provider
  self.assertIs(actor, wrapped_actor)  # caller receives the LoRA-wrapped model

@pytest.mark.cpu_only
def test_create_models_and_meshes_skips_lora_when_disabled(self):
  """When lora.enabled=False, qwix.apply_lora_to_model is not called."""
  base_actor = mock.MagicMock()
  reference = mock.MagicMock()
  mesh = mock.MagicMock()

  trainer_config = SimpleNamespace(
      load_checkpoint_only_once=False,
      vllm_additional_config="",
      lora=SimpleNamespace(enabled=False, rank=8, alpha=16.0, module_path="layers/.*"),
  )
  sampler_config = SimpleNamespace(mesh_axes=("data",))

  # Same patch set as the enabled-path test, minus LoraProvider (it should never be reached).
  model_loader = mock.patch(
      "maxtext.trainers.post_train.rl.train_rl.get_maxtext_model",
      side_effect=[(reference, mesh), (base_actor, mesh)],
  )
  device_mesh = mock.patch("maxtext.trainers.post_train.rl.train_rl.maxtext_utils.create_device_mesh", return_value=[])
  mesh_cls = mock.patch("maxtext.trainers.post_train.rl.train_rl.Mesh", return_value=mesh)
  apply_patch = mock.patch("maxtext.trainers.post_train.rl.train_rl.qwix.apply_lora_to_model")

  with model_loader, device_mesh, mesh_cls, apply_patch as apply_fn:
    _, _, actor, _, _ = train_rl.create_models_and_meshes(trainer_config, sampler_config, ["dev"], ["dev"])

  apply_fn.assert_not_called()
  # Actor should be the original model, not a LoRA-wrapped one.
  self.assertIs(actor, base_actor)


if __name__ == "__main__":
  # Allow running this test module directly (outside a pytest invocation).
  unittest.main()
Loading