
feat(training/quickstart): add self-contained training demo #15

Open
dmarx wants to merge 3 commits into main from dmarx/training-quickstart

Conversation


@dmarx dmarx commented Aug 6, 2025

This PR publishes an sbatch script that launches a "self-contained" training demo, mainly intended for debugging and smoke testing. As such, it's an odd fit for the reference-architecture repo, because it actually illustrates an anti-pattern: shipping training data packaged into the container. That's useful for the expedient purposes this container typically serves, but we might want to hold off on merging to main until we've added a reference training example that doesn't demonstrate the anti-pattern.

Update: the anti-pattern is resolved. The demo now uses the megatron container from ml-containers rather than a special container prepackaged with a tokenizer and data. I also made wandb optional, so there's no additional setup besides SUNK.

@dmarx dmarx requested a review from tmadhyastha-cw August 6, 2025 20:48
@bradbeam
Member

I love the concept here. Is there much of a lift needed to make this work for gh200?

@bradbeam
Member

> shipping training data packaged into the container.

Could we make use of a public bucket instead?

* basic megatron container
  * latest torch-extras base
  * public NVIDIA/megatron-lm
* use data mocking
* auto-download tokenizer (built-in)
* wandb optional, automatic if ~/.netrc present
@dmarx dmarx marked this pull request as ready for review March 11, 2026 05:20
@dmarx dmarx requested a review from Eta0 March 11, 2026 16:11
@dmarx dmarx requested a review from sangstar March 31, 2026 16:06

@sangstar sangstar left a comment


I'm not entirely sure where the line is here between "user literally just runs this and it'll work" and "user is assumed to make some adjustments for this to work," but in my opinion this script currently makes assumptions that won't hold generically for what I'd think of as a demo.

Comments enclosed. Cheers!

Comment on lines +4 to +5
#SBATCH --ntasks-per-node 8
#SBATCH --gpus-per-node 8

@sangstar sangstar Mar 31, 2026


This may not be true for all jobs with GPUs; some have fewer than this allocatable.

Also, this is a personal preference thing, but I highly prefer the setup --ntasks-per-node=1 --gpus-per-task=8 whenever possible. One task running the sbatch before fanning out in the training script lets you install things and modify any other filesystem state without running into race conditions.

The only change required for this to work is to simply wrap the entrypoint script in torchrun, which also has the added benefit of automatically assigning the correct identifying environment variables to each rank, like:

torchrun \
  --nproc_per_node="$SLURM_GPUS_PER_NODE" \
  --nnodes="$SLURM_NNODES" \
  --node_rank="$SLURM_NODEID" \
  --master_addr="$MASTER_ADDR" \
  --master_port="$MASTER_PORT" \
  pretrain_gpt.py "${train_args[@]}"

Author

@dmarx dmarx Mar 31, 2026


That's a great callout and suggestion, thanks. I'm also not a huge fan of hardcoding this to 1 node. An alternative could be removing all three of those directives from the top entirely and instead modeling usage for the user with CLI args, e.g.

sbatch launch_training.sbatch \
  --nodes 1 \
  --ntasks-per-node 1 \
  --gpus-per-task=8

thoughts?

Comment on lines +12 to +16
export NCCL_SOCKET_IFNAME=eth0
export SHARP_COLL_ENABLE_PCI_RELAXED_ORDERING=1
export NCCL_COLLNET_ENABLE=0
export NCCL_IB_HCA=ibp
export UCX_NET_DEVICES=ibp0:1,ibp1:1,ibp2:1,ibp3:1,ibp4:1,ibp5:1,ibp6:1,ibp7:1

These environment variables assume very specific things, and as such aren't appropriate if this is meant to be trivially run by users. They assume:

  • An ethernet NIC for TCP/IP comms (this is probably a totally fine assumption)
  • Tries to run SHARP if possible (as long as this doesn't error out if a node doesn't have SHARP this is also fine)
  • Disables the NCCL CollNet plugin. This shouldn't be unconditionally disabled, and for what it's worth disabling it may not allow SHARP in the first place.
  • Assumes the node running has IB and exactly 8 ports

If this is meant to be pretty well generalizable as a demo script we may as well just remove these environment variables entirely and let NCCL figure these out on its own.
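A trimmed-down alternative along those lines might look like this (a sketch; NCCL_DEBUG is optional and only makes NCCL's autodetected choices visible in the job log):

```shell
# Sketch: no hardware-specific pins at all; let NCCL probe the topology
# itself. Keep a debug knob so the transport/NIC it selects shows up in logs.
export NCCL_DEBUG=INFO   # optional; remove once the demo is known-good
```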

Author


Very fair, and also echoes Brad's suggestion that we keep GH200 in mind as well. Considering an environment where these settings would actually be appropriate/optimal, what would be the impact of removing them? Just a little extra setup time as NCCL probes the topology?

A simple path forward could be to just do a couple of different versions of this demo with settings pre-optimized for different topologies. A higher LOE path for the future could be a script that figures out what the environment is and chooses recommended env vars based on what it finds.
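The higher-LOE detection path could be sketched roughly like this (function names and fallback behavior are hypothetical; it only keys off the presence of InfiniBand devices in sysfs, parameterized so it can be tested off-cluster):

```shell
#!/usr/bin/env bash
# Hypothetical helper: derive NCCL settings from the hardware actually
# present, instead of hardcoding one topology.

# List InfiniBand HCAs under a sysfs root (argument allows testing).
detect_ib_devices() {
  ls "${1:-/sys/class/infiniband}" 2>/dev/null
}

configure_nccl() {
  local ib_devs
  ib_devs="$(detect_ib_devices "$1")"
  if [ -n "$ib_devs" ]; then
    # IB present: pin NCCL to the detected HCAs, comma-separated.
    export NCCL_IB_HCA="$(printf '%s\n' "$ib_devs" | paste -sd, -)"
  else
    # No IB detected: leave NCCL to autodetect a TCP fallback on its own.
    unset NCCL_IB_HCA UCX_NET_DEVICES
  fi
}
```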



# Use squished container if available, otherwise download and save it.
export REMOTE_IMAGE_URI="ghcr.io#coreweave/ml-containers/megatron:dmarx-megatron-update-0c68584"

Last I checked, valid GitHub credentials need to be configured via enroot to download this. Is this documented?

Author


This container is public; does that still require GitHub creds?

Author


In any event, I'll add that to the instructions; I've got it written up for the megatron demo already anyway.
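For reference, enroot credentials are netrc-style, so the instructions might include something like the following sketch (placeholder values, not real credentials):

```
# Contents of $ENROOT_CONFIG_PATH/.credentials
# (defaults to ~/.config/enroot/.credentials)
machine ghcr.io login <github-username> password <personal-access-token>
```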

Comment on lines +61 to +66
export CUDA_DEVICE_MAX_CONNECTIONS=1

export WORLD_SIZE="${SLURM_NTASKS:?}"
export RANK="${SLURM_PROCID:?}"
export LOCAL_RANK="${SLURM_LOCALID:?}"
export CUDA_DEVICE_ORDER='PCI_BUS_ID'

These are unnecessary when using torchrun, fwiw. If our container has torchrun, I strongly recommend we make it the preferred pattern; torchrun is automatically installed alongside torch, so we should be able to assume it's available.
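Assuming the torchrun launch suggested earlier, the wrapper itself injects RANK, LOCAL_RANK, WORLD_SIZE, MASTER_ADDR, and MASTER_PORT into each worker process, so this section shrinks to something like (a sketch):

```shell
# Sketch: with torchrun as the launcher, the per-process identity variables
# (RANK, LOCAL_RANK, WORLD_SIZE, MASTER_ADDR, MASTER_PORT) are set for us.
# This Megatron-specific setting is the only export that still earns its keep.
export CUDA_DEVICE_MAX_CONNECTIONS=1
```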

Comment on lines +102 to +103
--tensor-model-parallel-size 4 \
--pipeline-model-parallel-size 1 \

Assumes the job's available VRAM can support 2 separate model replicas during training. Since we can't know the exact VRAM a generic job will have, setting the tensor parallelism to the world size is probably a good defensive option.
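A defensive default along those lines, assuming one task per GPU so that SLURM_NTASKS equals the total GPU count (and that the model's attention-head count is divisible by it, which Megatron requires), might be:

```shell
  --tensor-model-parallel-size "${SLURM_NTASKS:?}" \
  --pipeline-model-parallel-size 1 \
```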

#SBATCH --constraint gpu
#SBATCH --job-name test
#SBATCH --output test.%j
#SBATCH --export all

It's probably not a good idea to leak all environment variables in from the login node, and the job shouldn't need any of them.
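The minimal fix would be (the job still receives the SLURM_* variables Slurm sets itself):

```shell
#SBATCH --export NONE
```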

Comment on lines +150 to +153
--num-layers 32 \
--hidden-size 4096 \
--seq-length 8192 \
--max-position-embeddings 8192 \

@sangstar sangstar Mar 31, 2026


Again, this makes assumptions about the VRAM capacity available to jobs, but you need something here, so it's kind of unavoidable that there's no magic set of parameters that will always work. Could maybe consider defensively reducing these by a factor of 2, but it's a nitpick.
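Halving as suggested would look like this (purely illustrative numbers; note that shrinking hidden-size and num-layers changes the model itself, not just the per-sequence memory):

```shell
  --num-layers 16 \
  --hidden-size 2048 \
  --seq-length 4096 \
  --max-position-embeddings 4096 \
```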

