Skip to content

Conversation

@Kuangdd01
Copy link
Collaborator

@Kuangdd01 Kuangdd01 commented Oct 7, 2025

What does this PR do?

Try to introduce mcore by mcore_adapter

Todo List

  • avoid importerror when mcore_adapter is not installed
  • training with moe-type model
  • check performance difference
  • make mca cli entry better [ugly now]
  • overwrite original trainer from mca for recording or callbacks

Quick Start

you can refer this document for detailed info.

Env Setup

📦 pip

# for higher torch version, refer to https://github.com/alibaba/ROLL/blob/main/docker/Dockerfile.torch280
# for megatron-core
pip install torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 --index-url https://download.pytorch.org/whl/cu124

pip install \
    numpy==1.26.4 \
    optree>=0.13.0 \
    spacy==3.7.5 \
    weasel==0.4.1 \
    transformer-engine[pytorch]==2.2.0 \
    megatron-core==0.13.0 \
    deepspeed==0.16.4 

pip uninstall -y opencv opencv-python opencv-python-headless
pip install opencv-python-headless==4.11.0.86
pip install "git+https://github.com/alibaba/roll.git#subdirectory=mcore_adapter"

# for llamafactory
git clone --depth 1 https://github.com/hiyouga/LLaMA-Factory.git
cd LLaMA-Factory
pip install -e ".[torch,metrics]" --no-build-isolation

🐳 docker

We offer a dockerfile here.

our experiment

1. Qwen2VL-7B-instruct training

a800 * 8, dataset: llava_en_1k, cutoff_len: 4096

Train method details speed
deepspeed zero3 example/zero3_config, global_bsz=16 elapsed_time": "0:09:55"
megatron tp=4,pp=2,global_bsz=16 elapsed_time": "0:06:37"

The setting may not be fair, but it's just a correctness check.

loss curve
image

2. Qwen3-30B-A3B-2507-Instruct

a800 * 16, dataset: OpenR1-Math-94k, cutoff_len: 4096

Train method details speed
deepspeed zero3 example/zero3_config, global_bsz=16 54.0s/it (not finished)
megatron pp=4 ep=2,global_bsz=16 4.95s/it

loss curve mcore only
image

3. E2E sft on identity dataset

image

Acknowledgement

Thanks to the ROLL team for this mcore adapter plugin!

Todo

  1. For now, we only support full finetune on this workflow, peft methods not supported.
  2. Add docs and examples
  3. Check DPO

Before submitting

@gemini-code-assist
Copy link
Contributor

Summary of Changes

Hello @Kuangdd01, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request, authored by Kuangdd01, introduces initial support for Megatron-LM training via the mcore_adapter library. The changes involve adding a new use_mca flag to enable mcore_adapter specific training workflows for PT, SFT, and DPO stages. This includes using mcore_adapter's TrainingArguments and Trainer classes, forcing torchrun for distributed training, and implementing MCA-specific data handling. A new example configuration file is provided, and conditional model freezing for qwen2_vl models is added. The PR is marked as [WIP] with a clear todo list for future improvements and checks.

Highlights

  • Intent: This pull request introduces initial support for Megatron-LM training using the mcore_adapter library. The goal is to enable distributed training capabilities, including tensor and pipeline parallelism, for various training stages (PT, SFT, DPO).
  • Key Changes: The core changes involve integrating mcore_adapter into the training workflow. A new use_mca argument is added to FinetuningArguments to enable this feature. When use_mca is active, the system is configured to force torchrun for distributed training, and mcore_adapter's TrainingArguments and Trainer classes are used. Dedicated run_pt_mca, run_sft_mca, and run_dpo_mca functions are implemented to handle the specific requirements of mcore_adapter, including adjustments to data collators and dataset cutoff_len due to MCA's shift logic. Conditional model freezing for qwen2_vl models is also introduced for SFT with MCA. A new example configuration file (qwen2_vl_full.yaml) demonstrates how to use mcore_adapter with specific parallelization parameters.
  • New Files: New files include examples/megatron/qwen2_vl_full.yaml (a configuration for MCA training), src/llamafactory/train/mca/__init__.py (MCA module initialization), src/llamafactory/train/mca/trainer.py (a placeholder for future trainer overrides), and src/llamafactory/train/mca/workflow.py (implementing MCA-specific PT, SFT, and DPO workflows).
  • Pending Tasks/WIP: This PR is marked as Work In Progress ([WIP]). The author has outlined several pending tasks, including fixing GPU memory collection during saving, supporting training with MoE-type models, checking performance differences, improving the CLI entry, and overriding the original trainer from MCA for recording or callbacks.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces support for Megatron-LM training using mcore_adapter. The changes are comprehensive, touching argument parsing, training workflows, and configuration. The core logic resides in the new src/llamafactory/train/mca/workflow.py file. While the implementation appears to cover the main requirements, there are several areas for improvement regarding maintainability, clarity, and robustness. My review focuses on addressing code smells like "hacks" and FIXMEs, improving the extensibility of model-specific logic, and making the code safer by avoiding in-place modifications of shared objects. Additionally, there's an inconsistency in how MCA-specific logic is triggered (environment variable vs. config option) that should be resolved for better predictability.

Comment on lines 77 to 81
if is_env_enabled("USE_MCA"):
# force use torchrun
os.environ["FORCE_TORCHRUN"] = "1"

if command == "train" and (is_env_enabled("FORCE_TORCHRUN") or (get_device_count() > 1 and not use_ray())):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The use of is_env_enabled("USE_MCA") to control logic in both cli.py and parser.py is inconsistent with using finetuning_args.use_mca in tuner.py. This can lead to confusing behavior where setting use_mca: true in a config file doesn't work unless the USE_MCA environment variable is also set. This dependency on an environment variable for a core logic switch makes configuration less transparent and more error-prone.

To make use_mca the single source of truth, consider a two-pass approach for argument parsing. You could first parse just the FinetuningArguments to check the value of use_mca, and then, based on that, select the appropriate full set of argument classes for the second pass. This would eliminate the need for the USE_MCA environment variable.

@Kuangdd01 Kuangdd01 marked this pull request as draft October 7, 2025 17:00
Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces support for Megatron-LM training using mcore_adapter. The changes are quite extensive, adding new example configurations, modifying argument parsing, and creating new training workflows. The overall strategy of conditionally using mcore_adapter based on an environment variable is sound. However, I've identified several areas for improvement, particularly concerning code duplication in the new workflow files, and some inconsistencies in argument handling. My review includes specific suggestions to enhance code clarity, consistency, and maintainability.

template = get_template_and_fix_tokenizer(tokenizer, data_args)

# dataset needs +1 then cut back due to MCA shift logic
data_args.cutoff_len += 1
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The pattern of incrementing data_args.cutoff_len, calling get_dataset, and then decrementing it is repeated across run_pt, run_sft, and run_dpo. This in-place modification can be error-prone and makes the code less clean. Consider creating a copy of data_args with the modified cutoff_len or using a context manager to handle this temporarily to avoid side effects and reduce repetition.


model = AutoModel.from_pretrained(model_args.model_name_or_path, training_args)

from transformers import DataCollatorForSeq2Seq
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The import from transformers import DataCollatorForSeq2Seq is located inside the run_pt function. It's generally better practice to place all imports at the top of the file for clarity and to avoid repeated imports. Please move this to the top of the file.

@Kuangdd01 Kuangdd01 marked this pull request as ready for review October 18, 2025 07:41
@Kuangdd01 Kuangdd01 changed the title [WIP][feat] support megatron-LM training by mcore_adapter [feat] support megatron-LM training by mcore_adapter Oct 18, 2025
@Kuangdd01 Kuangdd01 assigned hiyouga and unassigned hiyouga Oct 18, 2025
@Kuangdd01 Kuangdd01 requested a review from hiyouga October 18, 2025 07:42
@Kuangdd01 Kuangdd01 added the pending This problem is yet to be addressed label Oct 18, 2025
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Copy link
Owner

@hiyouga hiyouga left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@hiyouga hiyouga added solved This problem has been already solved and removed pending This problem is yet to be addressed labels Oct 26, 2025
@hiyouga hiyouga merged commit 1317057 into hiyouga:main Oct 26, 2025
16 checks passed
@gemini-code-assist gemini-code-assist bot mentioned this pull request Oct 26, 2025
2 tasks
Jasonqi146 pushed a commit to Jasonqi146/LLaMA-Factory that referenced this pull request Nov 19, 2025
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Yaowei Zheng <hiyouga@buaa.edu.cn>
penfever pushed a commit to mlfoundations/LLaMA-Factory that referenced this pull request Nov 22, 2025
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Yaowei Zheng <hiyouga@buaa.edu.cn>
@tic-top
Copy link

tic-top commented Nov 24, 2025

When I try to run

USE_MCA=1 llamafactory-cli train examples/megatron/qwen3_moe_full.yaml

It shows that the mca_config is not exist in model_name_or_path.

[rank7]: Traceback (most recent call last):
[rank7]:   File "/home/LLaMA-Factory/LLaMA-Factory/src/llamafactory/launcher.py", line 185, in <module>
[rank7]:     run_exp()
[rank7]:   File "/home/LLaMA-Factory/LLaMA-Factory/src/llamafactory/train/tuner.py", line 132, in run_exp
[rank7]:     _training_function(config={"args": args, "callbacks": callbacks})
[rank7]:   File "/home/LLaMA-Factory/LLaMA-Factory/src/llamafactory/train/tuner.py", line 79, in _training_function
[rank7]:     run_sft_mca(model_args, data_args, training_args, finetuning_args, callbacks)
[rank7]:   File "/home/LLaMA-Factory/LLaMA-Factory/src/llamafactory/train/mca/workflow.py", line 166, in run_sft
[rank7]:     model = AutoModel.from_pretrained(model_args.model_name_or_path, training_args)
[rank7]:             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "/home/miniconda3/envs/sft/lib/python3.11/site-packages/mcore_adapter/models/auto/modeling_auto.py", line 57, in from_pretrained
[rank7]:     raise ValueError(f"No valid config found in {model_name_or_path}")
[rank7]: ValueError: No valid config found in Qwen/Qwen3-30B-A3B-Instruct-2507

Do we need to convert hf model to Megatron model before tunning?

@Kuangdd01
Copy link
Collaborator Author

When I try to run

USE_MCA=1 llamafactory-cli train examples/megatron/qwen3_moe_full.yaml

It shows that the mca_config is not exist in model_name_or_path.

[rank7]: Traceback (most recent call last):
[rank7]:   File "/home/LLaMA-Factory/LLaMA-Factory/src/llamafactory/launcher.py", line 185, in <module>
[rank7]:     run_exp()
[rank7]:   File "/home/LLaMA-Factory/LLaMA-Factory/src/llamafactory/train/tuner.py", line 132, in run_exp
[rank7]:     _training_function(config={"args": args, "callbacks": callbacks})
[rank7]:   File "/home/LLaMA-Factory/LLaMA-Factory/src/llamafactory/train/tuner.py", line 79, in _training_function
[rank7]:     run_sft_mca(model_args, data_args, training_args, finetuning_args, callbacks)
[rank7]:   File "/home/LLaMA-Factory/LLaMA-Factory/src/llamafactory/train/mca/workflow.py", line 166, in run_sft
[rank7]:     model = AutoModel.from_pretrained(model_args.model_name_or_path, training_args)
[rank7]:             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "/home/miniconda3/envs/sft/lib/python3.11/site-packages/mcore_adapter/models/auto/modeling_auto.py", line 57, in from_pretrained
[rank7]:     raise ValueError(f"No valid config found in {model_name_or_path}")
[rank7]: ValueError: No valid config found in Qwen/Qwen3-30B-A3B-Instruct-2507

Do we need to convert hf model to Megatron model before tunning?

I guess it is a network error. Try to load model locally.

@1277331747
Copy link

model_name_or_path: /data/ptmodels/Qwen/Qwen3-Next-80B-A3B-Instruct

do_train: true
stage: sft
finetuning_type: full # only support full for now
dataset: A,B,C,D,E,F,G,H
preprocessing_num_workers: 8
cutoff_len: 1024
template: qwen3_nothink

output_dir: /data/sftmodels/FinMDT-ThoughtPOP/llma-Qwen-next-sft-dot1
per_device_train_batch_size: 1
gradient_accumulation_steps: 2
num_train_epochs: 1
learning_rate: 3e-6
logging_steps: 1
save_steps: 1
lr_scheduler_type: constant
bf16: true

tensor_model_parallel_size: 1
sequence_parallel: true
pipeline_model_parallel_size: 4
bias_activation_fusion: true
apply_rope_fusion: true
use_distributed_optimizer: true
overlap_param_gather: false
overlap_grad_reduce: false
moe_grouped_gemm: true
moe_token_dispatcher_type: alltoall
expert_model_parallel_size: 2
recompute_granularity: full

sh

cd /data/code/llama-mcore/LLaMA-Factory &&
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 &&
PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True &&
USE_MCA=1 llamafactory-cli train /data/code/llama-mcore/LLaMA-Factory/examples/megatron/qwen3_moe_full.yaml

为什么我这样配置8张H20 也会出现

torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 20.00 MiB. GPU 5 has a total capacity of 95.08 GiB of which 17.88 MiB is free. Including non-PyTorch memory, this process has 95.06 GiB memory in use. Of the allocated memory 91.95 GiB is allocated by PyTorch, and 12.15 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation. See documentation for Memory Management (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

@Kuangdd01
Copy link
Collaborator Author

For Qwen3-Next(80B), we may need 16*80GB GPU memory at least for model parameters amd optimizer loading if we do use normal training recipes.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

solved This problem has been already solved

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants