From 29cc9bc8beb4450b318aab779e7b4ed20d00321b Mon Sep 17 00:00:00 2001 From: Andy Twigg Date: Wed, 18 Mar 2026 22:10:40 +0000 Subject: [PATCH 1/3] first commit of lora integration (wip) --- src/maxtext/configs/post_train/rl.yml | 12 ++++++++++++ src/maxtext/configs/types.py | 13 +++++++++++++ src/maxtext/layers/nnx_wrappers.py | 9 ++++++--- src/maxtext/trainers/post_train/rl/train_rl.py | 18 +++++++++++++++++- 4 files changed, 48 insertions(+), 4 deletions(-) diff --git a/src/maxtext/configs/post_train/rl.yml b/src/maxtext/configs/post_train/rl.yml index da455a13e2..830cc66c55 100644 --- a/src/maxtext/configs/post_train/rl.yml +++ b/src/maxtext/configs/post_train/rl.yml @@ -32,6 +32,18 @@ rollout_expert_parallelism: 1 # ====== Reproducibility ====== data_shuffle_seed: 42 +# ====== LoRA ====== +# Low-Rank Adaptation for the actor model. When enabled, only the LoRA parameters +# are trained and checkpointed, significantly reducing memory and compute. +lora: + enabled: False + rank: 32 + alpha: 64.0 + # Regex matching module paths to apply LoRA to. + # Qwix uses re.fullmatch over slash-separated module paths like: + # layers/self_attention/query, layers/self_attention/key, ... + module_path: 'layers/(self_attention/(query|key|value|out)|mlp/(wi_0|wi_1|wo))' + # ====== RL ====== # This config includes RL algorithm variations such as grpo or gspo-token rl: diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index 3df51ac106..c4fa423450 100644 --- a/src/maxtext/configs/types.py +++ b/src/maxtext/configs/types.py @@ -1622,6 +1622,15 @@ class VLLM(BaseModel): ) vllm_hf_config_path: str = Field("", description="Path to HuggingFace model config for MaxText model.") + +class LoraConfig(BaseModel): + """Configuration for LoRA (Low-Rank Adaptation) applied to the actor model during RL training.""" + + enabled: bool = Field(False, description="If True, apply LoRA to the actor model instead of full-weight training.") + rank: int = Field(4, description="Rank of the LoRA decomposition.") + alpha: float = Field(8.0, description="Alpha scaling parameter for LoRA (effective LR scale = alpha / rank).") + module_path: str = Field(".*", description="Regex matching which module paths to apply LoRA to.") + class RL(BaseModel): """Configuration for RL algorithms like Group Relative Policy Optimization (GRPO) among others.""" @@ -1948,6 +1957,10 @@ class MaxTextConfig( default_factory=RL, description="Configuration for RL algorithms like Group Relative Policy Optimization (GRPO).", ) + lora: LoraConfig = Field( + default_factory=LoraConfig, + description="Configuration for LoRA applied to the actor model during RL training.", + ) model_config = ConfigDict(extra="forbid", protected_namespaces=()) @model_validator(mode="before") diff --git a/src/maxtext/layers/nnx_wrappers.py b/src/maxtext/layers/nnx_wrappers.py index fe41af9b40..bacd4e66a1 100644 --- a/src/maxtext/layers/nnx_wrappers.py +++ b/src/maxtext/layers/nnx_wrappers.py @@ -253,9 +253,12 @@ def __call__( out, updates = self.to_nnx__module.init_with_output(_rngs, *args, method=method, **kwargs) else: nnx_attrs = { - k: v - for k, v in vars(self).items() - if not k.startswith("to_nnx__") and not k.startswith("_pytree__") and not k.startswith("_object__") + k: v + for k, v in vars(self).items() + if not k.startswith("to_nnx__") + and not k.startswith("_pytree__") + and not k.startswith("_object__") + and isinstance(v, dict) # linen variable collections are always dicts; skip plain attrs (e.g. qwix metadata) } variables = nnx_attrs_to_linen_vars(nnx_attrs) diff --git a/src/maxtext/trainers/post_train/rl/train_rl.py b/src/maxtext/trainers/post_train/rl/train_rl.py index 7f5a33eed6..f30cf3ff60 100644 --- a/src/maxtext/trainers/post_train/rl/train_rl.py +++ b/src/maxtext/trainers/post_train/rl/train_rl.py @@ -63,6 +63,7 @@ from orbax import checkpoint as ocp from pprint import pprint from transformers import AutoTokenizer +import qwix from tunix.rl import rl_cluster as rl_cluster_lib from tunix.rl.rollout import base_rollout from tunix.rl.grpo.grpo_learner import GrpoConfig, GrpoLearner @@ -79,7 +80,6 @@ from maxtext.input_pipeline.instruction_data_processing import load_template_from_file from maxtext.utils import max_logging, max_utils, maxtext_utils, model_creation_utils - def get_maxtext_model(config, devices=None): """ Load MaxText model with Tunix adapter. @@ -419,6 +419,22 @@ def create_models_and_meshes(trainer_config, sampler_config, trainer_devices, sa max_logging.log("Creating policy model with same config as reference model on trainer mesh") actor_model, actor_mesh = get_maxtext_model(trainer_config, trainer_devices) + # add lora adapter to actor model + if trainer_config.lora.enabled: + max_logging.log( + f"Applying LoRA to actor model: rank={trainer_config.lora.rank}, " + f"alpha={trainer_config.lora.alpha}, module_path='{trainer_config.lora.module_path}'" + ) + lora_provider = qwix.LoraProvider( + module_path=trainer_config.lora.module_path, + rank=trainer_config.lora.rank, + alpha=trainer_config.lora.alpha, + ) + model_input = actor_model.get_model_input() + with actor_mesh: + actor_model = qwix.apply_lora_to_model(actor_model, lora_provider, rngs=nnx.Rngs(0), **model_input) + max_logging.log("LoRA applied to actor model successfully.") + return reference_model, reference_mesh, actor_model, actor_mesh, rollout_mesh From 329058d008964eef05df0bcf5056d246ed73ea0d Mon Sep 17 00:00:00 2001 From: Andy Twigg Date: Wed, 18 Mar 2026 22:11:38 +0000 Subject: [PATCH 2/3] add tunix_adapter to lora --- src/maxtext/integration/tunix/tunix_adapter.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/src/maxtext/integration/tunix/tunix_adapter.py b/src/maxtext/integration/tunix/tunix_adapter.py index d509e512a1..65b36d652c 100644 --- a/src/maxtext/integration/tunix/tunix_adapter.py +++ b/src/maxtext/integration/tunix/tunix_adapter.py @@ -93,3 +93,16 @@ def lora_to_hf_mappings(self): return {} return self._vllm_weight_mapping.lora_to_hf_mappings() + + def get_model_input(self): + """Returns dummy inputs matching __call__ for qwix.apply_lora_to_model tracing.""" + import jax.numpy as jnp # pylint: disable=import-outside-toplevel + + dummy_batch_size = 1 + dummy_seq_len = 128 + return { + "input_tokens": jnp.ones((dummy_batch_size, dummy_seq_len), dtype=jnp.int32), + "positions": jnp.ones((dummy_batch_size, dummy_seq_len), dtype=jnp.int32), + "cache": None, + "attention_mask": None, + } \ No newline at end of file From d09e928e37906233a61e643db74abf3b257fe7c7 Mon Sep 17 00:00:00 2001 From: Andy Twigg Date: Wed, 18 Mar 2026 22:16:39 +0000 Subject: [PATCH 3/3] add lora unit test --- tests/post_training/unit/train_rl_test.py | 73 +++++++++++++++++++++++ 1 file changed, 73 insertions(+) diff --git a/tests/post_training/unit/train_rl_test.py b/tests/post_training/unit/train_rl_test.py index 4bb9831f60..40f95ef8b3 100644 --- a/tests/post_training/unit/train_rl_test.py +++ b/tests/post_training/unit/train_rl_test.py @@ -364,6 +364,79 @@ def get_dataset_side_effect(model_tokenizer, config, data_dir, split, data_files self.assertEqual(len(test_batch["prompts"]), 1) self.assertEqual(test_batch["prompts"][0], "short") + @pytest.mark.cpu_only + def test_create_models_and_meshes_applies_lora_when_enabled(self): + """When lora.enabled=True, create_models_and_meshes calls qwix.apply_lora_to_model on the actor.""" + mock_actor = mock.MagicMock() + mock_actor.get_model_input.return_value = {"input_tokens": "dummy"} + mock_reference = mock.MagicMock() + mock_mesh = mock.MagicMock() + lora_actor = mock.MagicMock(name="lora_actor") + + trainer_config = SimpleNamespace( + load_checkpoint_only_once=False, + vllm_additional_config="", + lora=SimpleNamespace(enabled=True, rank=8, alpha=16.0, module_path="layers/.*"), + ) + sampler_config = SimpleNamespace(mesh_axes=("data",)) + + with ( + mock.patch( + "maxtext.trainers.post_train.rl.train_rl.get_maxtext_model", + side_effect=[(mock_reference, mock_mesh), (mock_actor, mock_mesh)], + ), + mock.patch("maxtext.trainers.post_train.rl.train_rl.maxtext_utils.create_device_mesh", return_value=[]), + mock.patch("maxtext.trainers.post_train.rl.train_rl.Mesh", return_value=mock_mesh), + mock.patch("maxtext.trainers.post_train.rl.train_rl.qwix.LoraProvider") as mock_lora_provider_cls, + mock.patch( + "maxtext.trainers.post_train.rl.train_rl.qwix.apply_lora_to_model", return_value=lora_actor + ) as mock_apply, + ): + ref, ref_mesh, actor, actor_mesh, rollout_mesh = train_rl.create_models_and_meshes( + trainer_config, sampler_config, ["dev"], ["dev"] + ) + + # LoraProvider was constructed with the config values + mock_lora_provider_cls.assert_called_once_with(module_path="layers/.*", rank=8, alpha=16.0) + # apply_lora_to_model was called with the actor model and provider + mock_apply.assert_called_once() + call_args = mock_apply.call_args + self.assertIs(call_args[0][0], mock_actor) # first positional arg is the model + self.assertIs(call_args[0][1], mock_lora_provider_cls.return_value) # second is provider + # The returned actor should be the LoRA-wrapped model + self.assertIs(actor, lora_actor) + + @pytest.mark.cpu_only + def test_create_models_and_meshes_skips_lora_when_disabled(self): + """When lora.enabled=False, qwix.apply_lora_to_model is not called.""" + mock_actor = mock.MagicMock() + mock_reference = mock.MagicMock() + mock_mesh = mock.MagicMock() + + trainer_config = SimpleNamespace( + load_checkpoint_only_once=False, + vllm_additional_config="", + lora=SimpleNamespace(enabled=False, rank=8, alpha=16.0, module_path="layers/.*"), + ) + sampler_config = SimpleNamespace(mesh_axes=("data",)) + + with ( + mock.patch( + "maxtext.trainers.post_train.rl.train_rl.get_maxtext_model", + side_effect=[(mock_reference, mock_mesh), (mock_actor, mock_mesh)], + ), + mock.patch("maxtext.trainers.post_train.rl.train_rl.maxtext_utils.create_device_mesh", return_value=[]), + mock.patch("maxtext.trainers.post_train.rl.train_rl.Mesh", return_value=mock_mesh), + mock.patch("maxtext.trainers.post_train.rl.train_rl.qwix.apply_lora_to_model") as mock_apply, + ): + ref, ref_mesh, actor, actor_mesh, rollout_mesh = train_rl.create_models_and_meshes( + trainer_config, sampler_config, ["dev"], ["dev"] + ) + + mock_apply.assert_not_called() + # Actor should be the original model, not a LoRA-wrapped one + self.assertIs(actor, mock_actor) + if __name__ == "__main__": unittest.main()