Commits (38)
e6b7cab  feat(config): Added OmegaConf based serializer save_yaml_config_dict(). (BlueCrescent, Nov 27, 2025)
9fa51ec  feat(huggingface): Added conversion of distributed gpt2 checkpoints t… (BlueCrescent, Nov 27, 2025)
a73de85  chore: Merge branch 'fix_rotary_transform_deferred_init' into hf_chec… (BlueCrescent, Nov 28, 2025)
d7d0956  refactor: More robust parent directory path handling. (BlueCrescent, Nov 28, 2025)
8957f19  docs: better dcp to torch conversion docstring (BlueCrescent, Nov 28, 2025)
527a0d2  fix: Added handling for missing directory. (BlueCrescent, Nov 28, 2025)
95cead4  fix: use Path instead of string (BlueCrescent, Nov 28, 2025)
b8cf4ea  fix: use cpu device for dcp to torch converted checkpoints (BlueCrescent, Nov 28, 2025)
652e77a  fix: error handling if wrong model key is set in checkpoint conversion (BlueCrescent, Nov 28, 2025)
fca72dc  feat(utility): Moved MultiProcessingCudaEnv from tests to modalities. (BlueCrescent, Dec 2, 2025)
ace93c7  feat(utility): Added option to set init_process_group kwargs in cuda … (BlueCrescent, Dec 3, 2025)
53eb907  feat(utility): Extended get_model_from_config for distributed checkpo… (BlueCrescent, Dec 3, 2025)
3a4b46c  feat(huggingface): Added dcp specific conversion verification logic. (BlueCrescent, Dec 3, 2025)
642466d  fix(huggingface): Better dcp config conversion. (BlueCrescent, Dec 3, 2025)
f54abc6  feat(config): Added interoperability between PyTorchDtypes and Precis… (BlueCrescent, Dec 3, 2025)
3fbe498  fix(huggingface): Correct conversion of model dtype. (BlueCrescent, Dec 3, 2025)
ee4e244  fix(config): circular import (BlueCrescent, Dec 3, 2025)
1b4cfe0  feat(checkpointing): improvements for dcp to torch checkpoint conversion (BlueCrescent, Dec 5, 2025)
3a67ed9  revert(config): Removed PrecisionEnum <-> PyTorchDtypes interoperabil… (BlueCrescent, Dec 5, 2025)
ddbb8cc  fix(huggingface): output parity between dcp and converted hf checkpoints (BlueCrescent, Dec 5, 2025)
5a36d48  fix(model): Corrected type casting in rotary pos embeddings to match … (BlueCrescent, Dec 5, 2025)
bce2ae1  feat(utility): Added weights printing to print_forward_hook. (BlueCrescent, Dec 8, 2025)
5da0e7f  fix(requirements): Excluded bugged transformers versions. (BlueCrescent, Dec 8, 2025)
f902152  feat(utility): Added EnvOverride utility for temporary changing envir… (BlueCrescent, Dec 9, 2025)
d520095  fix(huggingface): Setting some environment variables when loading dcp… (BlueCrescent, Dec 9, 2025)
42a7e42  fix(checkpointing): Moved EnvOverride into load_dcp_config so that al… (BlueCrescent, Dec 11, 2025)
03e07f5  fix(huggingface): Made single node dcp config generation more robust … (BlueCrescent, Dec 11, 2025)
8a9ff2f  test(utility): Made manager shutdown in monitor_child_processes optio… (BlueCrescent, Dec 11, 2025)
9ae218d  test(huggingface): Added unit tests for dcp to hf conversion. (BlueCrescent, Dec 11, 2025)
36e2e25  fix(huggingface): For now, Huggingface version 5.0.0 is not tested. (BlueCrescent, Feb 8, 2026)
cfbe7df  chore: Merge remote-tracking branch 'origin/main' into hf_checkpoint_… (BlueCrescent, Feb 26, 2026)
1db0a6c  fix(huggingface): first fixes for conversion tests after main merge (BlueCrescent, Feb 26, 2026)
993d4ff  fix(huggingface): for comparison fsdp2 mixed precision <-> huggingfac… (BlueCrescent, Feb 27, 2026)
2e3076d  chore: Merge remote-tracking branch 'origin/main' into hf_checkpoint_… (BlueCrescent, Mar 17, 2026)
bc1ca36  fix(huggingface): added missing experiments_root_path when loading ch… (BlueCrescent, Mar 20, 2026)
1b2aca5  feat(huggingface): Added conversion of GPT2 models using pytorch rms … (BlueCrescent, Mar 24, 2026)
436bf95  chore: Merge remote-tracking branch 'origin/main' into hf_checkpoint_… (BlueCrescent, Apr 4, 2026)
3be2921  feat: moved find_free_port from test to src to use it in dcp to huggi… (BlueCrescent, Apr 7, 2026)
95 changes: 95 additions & 0 deletions src/modalities/checkpointing/convert_dcp_to_torch.py
@@ -0,0 +1,95 @@
import os
from pathlib import Path
from typing import Any

import torch
from torch.distributed.checkpoint.default_planner import _EmptyStateDictLoadPlanner
from torch.distributed.checkpoint.filesystem import FileSystemReader
from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE
from torch.distributed.checkpoint.state_dict_loader import _load_state_dict

from modalities.config.config import load_app_config_dict, save_yaml_config_dict


def convert_dcp_to_torch(dcp_checkpoint_dir: str, output_dir: str, model_key: str = "model_raw") -> str:
"""Converts a FSDP2 checkpoint to a standard PyTorch checkpoint.

Args:
dcp_checkpoint_dir (str): Directory containing the FSDP2 checkpoint files.
output_dir (str): Directory to save the converted PyTorch checkpoint.
model_key (str): Key of the model configuration in the modalities config.
Returns:
str: Path to the converted config file.
"""
os.makedirs(output_dir, exist_ok=True)
torch_checkpoint_file = os.path.join(output_dir, "pytorch_model.bin")
torch_config_file = convert_config_file(dcp_checkpoint_dir, output_dir, model_key, torch_checkpoint_file)
# TODO This is the (adapted) code from torch's dcp_to_torch_save(dcp_checkpoint_dir, torch_checkpoint_file)
# since we only want to convert the model state dict here. In future torch versions this function might
# support converting only parts of the checkpoint.
# (from torch.distributed.checkpoint.format_utils import dcp_to_torch_save)
sd: STATE_DICT_TYPE = {}
_load_state_dict(
sd, storage_reader=FileSystemReader(dcp_checkpoint_dir), planner=_EmptyStateDictLoadPlanner(), no_dist=True
)
torch.save(sd["app"]["model"], torch_checkpoint_file)
return torch_config_file


def convert_config_file(dcp_checkpoint_dir: str, output_dir: str, model_key: str, torch_checkpoint_file: str) -> str:
"""Converts the modalities config file for DCP to a config file for standard PyTorch checkpoint loading.
Args:
dcp_checkpoint_dir (str): Directory containing the DCP checkpoint files.
output_dir (str): Directory to save the converted config file.
model_key (str): Key of the model configuration in the modalities config.
torch_checkpoint_file (str): Path to the converted PyTorch checkpoint file.
Returns:
str: Path to the converted config file.
"""
config_src: str | None = find_yaml_config_in_dir(dcp_checkpoint_dir)
if config_src is None:
config_src = find_yaml_config_in_dir(os.path.join(dcp_checkpoint_dir, ".."))
if config_src is None:
raise FileNotFoundError("No YAML config file found in checkpoint directory or its parent.")
config_dst: str = os.path.join(output_dir, os.path.basename(config_src))
dcp_config: dict[str, Any] = load_app_config_dict(Path(config_src), experiment_id="-1")
torch_config: dict[str, Any] = {
"checkpointed_model": {
"component_key": "model",
"variant_key": "fsdp1_checkpointed",
"config": {
"checkpoint_loading": {
"component_key": "checkpoint_loading",
"variant_key": "torch",
"config": {
"device": 0,
"precision": "BF16", # FIXME Should this be configurable?
},
},
"model": {
"instance_key": "model",
"pass_type": "BY_REFERENCE",
},
"checkpoint_path": torch_checkpoint_file,
},
},
}
torch_config["model"] = dcp_config[model_key]
torch_config["model"]["config"]["use_meta_device"] = False
Review comment (Copilot AI, Feb 27, 2026): Accessing torch_config["model"]["config"]["use_meta_device"] without validation could raise KeyError or TypeError if the config structure from dcp_config[model_key] doesn't contain these nested keys. Consider adding validation or using .get() with proper error handling to provide a clear error message about the expected config structure.

Suggested change: replace

torch_config["model"]["config"]["use_meta_device"] = False

with

model_section = torch_config.get("model")
if not isinstance(model_section, dict):
    raise TypeError(
        f"Expected 'model' section in config file '{config_src}' to be a mapping, "
        f"but got {type(model_section).__name__!r}."
    )
model_config = model_section.get("config")
if not isinstance(model_config, dict):
    raise TypeError(
        f"Expected 'model.config' section in config file '{config_src}' to be a mapping, "
        f"but got {type(model_config).__name__!r}."
    )
model_config["use_meta_device"] = False
save_yaml_config_dict(torch_config, config_dst)
return config_dst


def find_yaml_config_in_dir(directory: str) -> str | None:
"""Finds the first YAML config file in the given directory.

Args:
directory (str): Directory to search for YAML files.

Returns:
str | None: Path to the found YAML file or None if not found.
"""
for filename in os.listdir(directory):
if filename.endswith(".yaml") or filename.endswith(".yml"):
return os.path.join(directory, filename)
return None
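
For orientation, a minimal usage sketch of the converter defined above; the checkpoint and output paths are hypothetical, and a YAML config is assumed to sit in the checkpoint directory or its parent, as the lookup above requires:

import torch

from modalities.checkpointing.convert_dcp_to_torch import convert_dcp_to_torch

# Hypothetical DCP checkpoint directory from a modalities training run.
config_path = convert_dcp_to_torch("checkpoints/run_0/model", "converted_torch")
# The converter writes pytorch_model.bin next to the adapted config.
state_dict = torch.load("converted_torch/pytorch_model.bin", map_location="cpu")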
19 changes: 16 additions & 3 deletions src/modalities/config/config.py
@@ -497,13 +497,14 @@ class ParallelDegreeConfig(BaseModel):
# Recursive type representing arbitrary-depth YAML config structures.
YAMLPrimitive = str | int | float | bool | None
YAMLValue: TypeAlias = YAMLPrimitive | Path | list["YAMLValue"] | dict[str, "YAMLValue"]
ConfigDictType: TypeAlias = dict[str, YAMLValue]


def load_app_config_dict(
config_file_path: Path,
experiment_id: Optional[str] = None,
additional_resolver_funs: Optional[dict[str, Resolver]] = None,
) -> dict[str, YAMLValue]:
) -> ConfigDictType:
"""Load the application configuration from the given YAML file.

Args:
@@ -512,7 +513,7 @@ def load_app_config_dict(
additional_resolver_funs (dict[str, Resolver], optional): Additional resolver functions.

Returns:
dict[str, YAMLValue]: Dictionary representation of the config file with arbitrary depth.
ConfigDictType: Dictionary representation of the config file with arbitrary depth.
"""

def cuda_env_resolver_fun(var_name: str) -> int | str | None:
@@ -528,6 +529,7 @@ def modalities_env_resolver_fun(var_name: str, kwargs: dict[str, Any]) -> str |
def node_env_resolver_fun(var_name: str) -> int | None:
if var_name == "num_cpus":
return os.cpu_count()
return None

OmegaConf.register_new_resolver("cuda_env", cuda_env_resolver_fun, replace=True)
modalities_env_kwargs: dict[str, Any] = {
@@ -546,6 +548,17 @@ def node_env_resolver_fun(var_name: str) -> int | None:
OmegaConf.register_new_resolver(resolver_name, resolver_fun, replace=True)

cfg = OmegaConf.load(config_file_path)
config_dict = cast(dict[str, YAMLValue], OmegaConf.to_container(cfg, resolve=True))
config_dict = cast(ConfigDictType, OmegaConf.to_container(cfg, resolve=True))

return config_dict


def save_yaml_config_dict(config_dict: ConfigDictType, output_file_path: Path) -> None:
"""Saves the given config dictionary as a YAML file.

Args:
config_dict (ConfigDictType): Configuration dictionary to save.
output_file_path (Path): Path to the output YAML file.
"""
cfg = OmegaConf.create(config_dict)
OmegaConf.save(cfg, output_file_path)
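
A quick round-trip sketch of the new serializer; the paths are hypothetical, and experiment_id="-1" mirrors its use in convert_config_file above:

from pathlib import Path

from modalities.config.config import load_app_config_dict, save_yaml_config_dict

config = load_app_config_dict(Path("configs/train.yaml"), experiment_id="-1")
# Resolvers run at load time, so the saved YAML contains resolved values.
save_yaml_config_dict(config, Path("out/train_resolved.yaml"))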
18 changes: 10 additions & 8 deletions src/modalities/conversion/gpt2/conversion_model.py
@@ -2,6 +2,7 @@
import torch.nn as nn
from tqdm import tqdm

from modalities.config.config import ConfigDictType
from modalities.conversion.gpt2.configuration_gpt2 import GPT2Config
from modalities.conversion.gpt2.modeling_gpt2 import GPT2DecoderLayer, GPT2ForCausalLM
from modalities.models.components.layer_norms import LayerNormConfig
@@ -10,13 +11,13 @@
from modalities.models.utils import ModelTypeEnum, get_model_from_config


def convert_model_checkpoint(modalities_config: dict) -> tuple[GPT2ForCausalLM, GPT2LLM]:
def convert_model_checkpoint(modalities_config: ConfigDictType) -> tuple[GPT2ForCausalLM, GPT2LLM]:
"""Converts the modalities model to a Huggingface transformers model.
Both the loaded modalities model and the converted Huggingface model are returned
so that they can be compared.

Args:
modalities_config (dict): Modalities config dictionary.
modalities_config (ConfigDictType): Modalities config dictionary.

Returns:
tuple[GPT2ForCausalLM, GPT2LLM]: Converted Hugging Face model and the original modalities model.
@@ -28,13 +29,13 @@ def convert_model_checkpoint(modalities_config: dict) -> tuple[GPT2ForCausalLM,
return hf_model, modalities_model


def convert_model_config(modalities_config: dict) -> GPT2Config:
def convert_model_config(modalities_config: ConfigDictType) -> GPT2Config:
"""Converts the modalities model configuration to a Huggingface transformers configuration.
For this, the model_raw or model section of the modalities config is used.
Corresponding entries are mapped to the Huggingface configuration.

Args:
modalities_config (dict): Modalities config dictionary.
modalities_config (ConfigDictType): Modalities config dictionary.

Returns:
GPT2Config: Converted Huggingface model configuration.
@@ -85,14 +86,15 @@ def check_converted_model(hf_model: GPT2ForCausalLM, modalities_model: GPT2LLM,
modalities_logits = modalities_model(inputs)[modalities_model.prediction_key].to("cpu")

assert llama_logits.shape == modalities_logits.shape
assert llama_logits.dtype == modalities_logits.dtype
assert torch.equal(llama_logits, modalities_logits)


def _check_conversion_criteria(model_config: dict) -> None:
def _check_conversion_criteria(model_config: ConfigDictType) -> None:
"""Checks that the modalities config fulfills criteria necessary for conversion

Args:
model_config (dict): model or model_raw part of the Modalities config dictionary.
model_config (ConfigDictType): model or model_raw part of the Modalities config dictionary.

Returns:
None
@@ -116,12 +118,12 @@ def _check_conversion_criteria(model_config: dict) -> None:
), "All norms must have the same eps setting."


def _get_layer_norm_value(config: dict, field: str) -> bool | float | int:
def _get_layer_norm_value(config: ConfigDictType, field: str) -> bool | float | int:
default = LayerNormConfig.model_fields[field].default
return config.get(field, default)


def _map_attention_type(config: dict):
def _map_attention_type(config: ConfigDictType) -> str:
if config["attention_implementation"] == "pytorch_flash":
attention_impl = "sdpa"
elif config["attention_implementation"] == "manual":
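
A sketch of how the typed helpers in this module chain together; the config path is hypothetical:

from pathlib import Path

from modalities.config.config import load_app_config_dict
from modalities.conversion.gpt2.conversion_model import convert_model_checkpoint, convert_model_config

modalities_config = load_app_config_dict(Path("configs/gpt2_small.yaml"), experiment_id="-1")
hf_config = convert_model_config(modalities_config)  # maps the model_raw/model section onto GPT2Config
hf_model, modalities_model = convert_model_checkpoint(modalities_config)  # both models, for parity checks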
108 changes: 83 additions & 25 deletions src/modalities/conversion/gpt2/convert_gpt2.py
@@ -1,11 +1,12 @@
"""
usage: convert_gpt2.py [-h] [--num_testruns NUM_TESTRUNS] [--device_modalities DEVICE_MODALITIES]
[--device_hf DEVICE_HF] modalities_config output_dir
[--device_hf DEVICE_HF] [--dcp] [--model_key MODEL_KEY]
modalities_input output_dir

Convert GPT-2 model checkpoint to Huggingface transformers format.

positional arguments:
modalities_config Path to the modalities config file.
modalities_input Path to the modalities config file or the dcp checkpoint dir.
output_dir Directory to save the converted model.

options:
@@ -16,13 +17,18 @@
Device for the modalities model.
--device_hf DEVICE_HF
Device for the Hugging Face model.
--dcp Indicates that the provided modalities checkpoint is in DCP format.
--model_key MODEL_KEY
Key of the model configuration in the modalities config.
"""

import argparse
import logging
import os
from pathlib import Path
from tempfile import TemporaryDirectory

from modalities.checkpointing.convert_dcp_to_torch import convert_dcp_to_torch
from modalities.config.config import load_app_config_dict
from modalities.conversion.gpt2.conversion_code import transfer_model_code
from modalities.conversion.gpt2.conversion_model import check_converted_model, convert_model_checkpoint
@@ -31,6 +37,71 @@
logger = logging.getLogger(__name__)


def main():
_ensure_logging()

os.environ["LOCAL_RANK"] = "0"
os.environ["WORLD_SIZE"] = "1"
os.environ["RANK"] = "0"

parser = argparse.ArgumentParser(description="Convert GPT-2 model checkpoint to Huggingface transformers format.")
parser.add_argument(
"modalities_input", type=str, help="Path to the modalities config file or the dcp checkpoint dir."
)
parser.add_argument("output_dir", type=str, help="Directory to save the converted model.")
parser.add_argument("--num_testruns", type=int, default=0, help="Number of test runs to perform.")
parser.add_argument("--device_modalities", type=str, default="cpu", help="Device for the modalities model.")
parser.add_argument("--device_hf", type=str, default="cpu", help="Device for the Hugging Face model.")
parser.add_argument(
"--dcp", action="store_true", help="Indicates that the provided modalities checkpoint is in DCP format."
)
parser.add_argument(
"--model_key", type=str, default="model_raw", help="Key of the model configuration in the modalities config."
)

args = parser.parse_args()

logger.info("Starting GPT-2 conversion script...")
if args.dcp:
convert_gpt2_dcp(
args.modalities_input,
args.output_dir,
args.num_testruns,
args.device_modalities,
args.device_hf,
args.model_key,
)
else:
convert_gpt2(
args.modalities_input,
args.output_dir,
args.num_testruns,
args.device_modalities,
args.device_hf,
)


def convert_gpt2_dcp(
distributed_cp_dir: str,
output_dir: str,
num_testruns: int = 0,
device_modalities: str = "cpu",
device_hf: str = "cpu",
model_key: str = "model_raw",
) -> None:
with TemporaryDirectory() as temp_dir:
logger.info("Converting DCP checkpoint to standard PyTorch checkpoint...")
modalities_config_path = convert_dcp_to_torch(distributed_cp_dir, temp_dir, model_key=model_key)
logger.info("Converting standard PyTorch checkpoint to Huggingface transformers format...")
convert_gpt2(
modalities_config_path,
output_dir,
num_testruns,
device_modalities,
device_hf,
)


def convert_gpt2(
modalities_config_path: str,
output_dir: str,
@@ -77,10 +148,6 @@ def convert_gpt2(
elif len(sentence_piece_tokenizer_configs) == 1:
tokenizer_model = modalities_config["tokenizer"]["config"]["tokenizer_model_file"]
bos_token_id, eos_token_id, pad_token_id, _ = convert_tokenizer(tokenizer_model, output_dir)
# The values bos=1, eos=2 and pad=None are set by default in the model config (as taken from Llama).
# Overwrite them, with the actual values from the internal SentencePiece tokenizer.
# Note, that the LlamaTokenizer wrapping around the SentencePiece tokenizer does not know about these values.
# The unk token id is not set in the model config.
hf_model.config.bos_token_id = bos_token_id
hf_model.config.eos_token_id = eos_token_id
hf_model.config.pad_token_id = pad_token_id
@@ -95,24 +162,15 @@ def convert_gpt2(
transfer_model_code(output_dir)


if __name__ == "__main__":
os.environ["LOCAL_RANK"] = "0"
os.environ["WORLD_SIZE"] = "1"
os.environ["RANK"] = "0"

parser = argparse.ArgumentParser(description="Convert GPT-2 model checkpoint to Huggingface transformers format.")
parser.add_argument("modalities_config", type=str, help="Path to the modalities config file.")
parser.add_argument("output_dir", type=str, help="Directory to save the converted model.")
parser.add_argument("--num_testruns", type=int, default=0, help="Number of test runs to perform.")
parser.add_argument("--device_modalities", type=str, default="cpu", help="Device for the modalities model.")
parser.add_argument("--device_hf", type=str, default="cpu", help="Device for the Hugging Face model.")
def _ensure_logging():
if not logger.hasHandlers():
handler = logging.StreamHandler()
handler.setLevel(logging.INFO)
formatter = logging.Formatter("%(asctime)s | %(levelname)s | %(name)s | %(message)s")
handler.setFormatter(formatter)
logger.addHandler(handler)
logger.setLevel(logging.INFO)

args = parser.parse_args()

convert_gpt2(
args.modalities_config,
args.output_dir,
args.num_testruns,
args.device_modalities,
args.device_hf,
)
if __name__ == "__main__":
main()
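
With the refactored entry point, the DCP path can also be driven from Python rather than the CLI; a sketch with hypothetical paths:

from modalities.conversion.gpt2.convert_gpt2 import convert_gpt2_dcp

# Equivalent to: convert_gpt2.py --dcp --model_key model_raw <dcp_dir> <out_dir>
convert_gpt2_dcp(
    distributed_cp_dir="checkpoints/run_0/model",  # hypothetical DCP checkpoint dir
    output_dir="converted_hf",
    num_testruns=2,  # optional logit-parity checks between the two models
)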
5 changes: 3 additions & 2 deletions src/modalities/models/utils.py
@@ -3,6 +3,7 @@
from pydantic import BaseModel

from modalities.config.component_factory import ComponentFactory
from modalities.config.config import ConfigDictType
from modalities.config.pydantic_if_types import PydanticPytorchModuleType
from modalities.registry.components import COMPONENTS
from modalities.registry.registry import Registry
@@ -21,12 +22,12 @@ class ModelTypeEnum(Enum):
CHECKPOINTED_MODEL = "checkpointed_model"


def get_model_from_config(config: dict, model_type: ModelTypeEnum):
def get_model_from_config(config: ConfigDictType, model_type: ModelTypeEnum):
"""
Retrieves a model from the given configuration based on the specified model type.

Args:
config (dict): The configuration dictionary.
config (ConfigDictType): The configuration dictionary.
model_type (ModelTypeEnum): The type of the model to retrieve.

Returns:
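
Downstream, the same ConfigDictType feeds model construction; a sketch assuming the config was written by convert_config_file above (path hypothetical):

from pathlib import Path

from modalities.config.config import load_app_config_dict
from modalities.models.utils import ModelTypeEnum, get_model_from_config

# The converted config exposes a "checkpointed_model" section, which
# ModelTypeEnum.CHECKPOINTED_MODEL selects.
config = load_app_config_dict(Path("converted_torch/train.yaml"), experiment_id="-1")
model = get_model_from_config(config, model_type=ModelTypeEnum.CHECKPOINTED_MODEL)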