From 97433c53b6523acb96a71ac14bd0e1ccf27ccc1b Mon Sep 17 00:00:00 2001 From: Cui-yshoho <73014084+Cui-yshoho@users.noreply.github.com> Date: Mon, 30 Mar 2026 10:47:20 +0800 Subject: [PATCH] [feat] support LlamaFactory SFT training by HyperParallel FSDP2 backend (#10289) --- src/llamafactory/extras/packages.py | 4 + src/llamafactory/hparams/finetuning_args.py | 18 ++ .../train/hyper_parallel/__init__.py | 18 ++ .../train/hyper_parallel/workflow.py | 179 ++++++++++++++++++ src/llamafactory/train/tuner.py | 18 +- 5 files changed, 235 insertions(+), 2 deletions(-) create mode 100644 src/llamafactory/train/hyper_parallel/__init__.py create mode 100644 src/llamafactory/train/hyper_parallel/workflow.py diff --git a/src/llamafactory/extras/packages.py b/src/llamafactory/extras/packages.py index c6328a7b0..eb373d091 100644 --- a/src/llamafactory/extras/packages.py +++ b/src/llamafactory/extras/packages.py @@ -70,6 +70,10 @@ def is_matplotlib_available(): return _is_package_available("matplotlib") +def is_hyper_parallel_available(): + return _is_package_available("hyper_parallel") + + def is_mcore_adapter_available(): return _is_package_available("mcore_adapter") diff --git a/src/llamafactory/hparams/finetuning_args.py b/src/llamafactory/hparams/finetuning_args.py index 053b4ab6a..2a1ecc943 100644 --- a/src/llamafactory/hparams/finetuning_args.py +++ b/src/llamafactory/hparams/finetuning_args.py @@ -482,6 +482,24 @@ class FinetuningArguments( ) }, ) + use_hyper_parallel: bool = field( + default=False, + metadata={ + "help": ( + "Whether or not to use HyperParallel distributed training backend (FSDP/TP). " + "Only supported for the 'sft' stage with full fine-tuning." + ) + }, + ) + hyper_parallel_args: str | None = field( + default=None, + metadata={ + "help": ( + "Path to a JSON file containing HyperParallel strategy arguments " + "(e.g., tp_size, param_dtype). Used when use_hyper_parallel=True." + ) + }, + ) use_muon: bool = field( default=False, metadata={"help": "Whether or not to use the Muon optimizer."}, diff --git a/src/llamafactory/train/hyper_parallel/__init__.py b/src/llamafactory/train/hyper_parallel/__init__.py new file mode 100644 index 000000000..6107a9ae7 --- /dev/null +++ b/src/llamafactory/train/hyper_parallel/__init__.py @@ -0,0 +1,18 @@ +# Copyright 2025 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .workflow import run_sft + + +__all__ = ["run_sft"] diff --git a/src/llamafactory/train/hyper_parallel/workflow.py b/src/llamafactory/train/hyper_parallel/workflow.py new file mode 100644 index 000000000..dd63901d8 --- /dev/null +++ b/src/llamafactory/train/hyper_parallel/workflow.py @@ -0,0 +1,179 @@ +# Copyright 2025 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING, Optional + +from ...data import SFTDataCollatorWith4DAttentionMask, get_dataset, get_template_and_fix_tokenizer +from ...extras.constants import IGNORE_INDEX +from ...extras.logging import get_logger +from ...extras.misc import calculate_tps +from ...extras.packages import is_hyper_parallel_available, is_transformers_version_greater_than +from ...extras.ploting import plot_loss +from ...model import load_model, load_tokenizer +from ..callbacks import SaveProcessorCallback +from ..sft.metric import ComputeAccuracy, ComputeSimilarity, eval_logit_processor +from ..trainer_utils import asft_loss_func, create_modelcard_and_push, create_ref_model, dft_loss_func, eaft_loss_func + + +if TYPE_CHECKING: + from transformers import Seq2SeqTrainingArguments, TrainerCallback + + from ...hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments + + +logger = get_logger(__name__) + + +def run_sft( + model_args: "ModelArguments", + data_args: "DataArguments", + training_args: "Seq2SeqTrainingArguments", + finetuning_args: "FinetuningArguments", + generating_args: "GeneratingArguments", + callbacks: Optional[list["TrainerCallback"]] = None, +): + if not is_hyper_parallel_available(): + raise ImportError( + "hyper_parallel is not installed. Please install it with `pip install hyper_parallel`." + ) + + from hyper_parallel.integration.llamafactory import HyperParallelArguments, HyperParallelTrainer # pylint: disable=C0415 + + tokenizer_module = load_tokenizer(model_args) + tokenizer = tokenizer_module["tokenizer"] + template = get_template_and_fix_tokenizer(tokenizer, data_args) + dataset_module = get_dataset(template, model_args, data_args, training_args, stage="sft", **tokenizer_module) + model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train) + + ref_model = None + if finetuning_args.use_asft_loss: + ref_model = create_ref_model(model_args, finetuning_args) + + data_collator = SFTDataCollatorWith4DAttentionMask( + template=template, + model=model if not training_args.predict_with_generate else None, + pad_to_multiple_of=8 if training_args.do_train else None, + label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id, + block_diag_attn=model_args.block_diag_attn, + attn_implementation=getattr(model.config, "_attn_implementation", None), + compute_dtype=model_args.compute_dtype, + **tokenizer_module, + ) + + # Metric utils + metric_module = {} + if training_args.predict_with_generate: + metric_module["compute_metrics"] = ComputeSimilarity(tokenizer=tokenizer) + elif finetuning_args.compute_accuracy: + metric_module["compute_metrics"] = ComputeAccuracy() + metric_module["preprocess_logits_for_metrics"] = eval_logit_processor + + # Keyword arguments for `model.generate` + gen_kwargs = generating_args.to_dict(obey_generation_config=True) + if is_transformers_version_greater_than("4.58.0"): + extra_ids = getattr(tokenizer, "additional_special_tokens_ids", None) + if not isinstance(extra_ids, list): + extra_special_tokens = getattr(tokenizer, "_extra_special_tokens", []) + string_tokens = [str(t) for t in extra_special_tokens] + extra_ids = tokenizer.convert_tokens_to_ids(string_tokens) + all_eos_ids = [tokenizer.eos_token_id] + [i for i in extra_ids if i != -1] + gen_kwargs["eos_token_id"] = list(dict.fromkeys(all_eos_ids)) + else: + gen_kwargs["eos_token_id"] = [tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids + gen_kwargs["pad_token_id"] = tokenizer.pad_token_id + + hp_args = HyperParallelArguments.from_finetuning_args(finetuning_args) + + callbacks = list(callbacks or []) + processor = tokenizer_module.get("processor") + if processor is not None: + callbacks.append(SaveProcessorCallback(processor)) + + compute_loss_func = None + if finetuning_args.use_dft_loss: + compute_loss_func = dft_loss_func + elif finetuning_args.use_eaft_loss: + compute_loss_func = lambda outputs, labels, num_items_in_batch=None: eaft_loss_func( # noqa: E731 + outputs, labels, num_items_in_batch, finetuning_args.eaft_alpha + ) + elif finetuning_args.use_asft_loss: + from functools import partial + + compute_loss_func = partial(asft_loss_func, asft_alpha=finetuning_args.asft_alpha) + + trainer = HyperParallelTrainer( + hp_args=hp_args, + model=model, + args=training_args, + finetuning_args=finetuning_args, + data_collator=data_collator, + callbacks=callbacks, + gen_kwargs=gen_kwargs, + ref_model=ref_model, + compute_loss_func=compute_loss_func, + **dataset_module, + **tokenizer_module, + **metric_module, + ) + + if finetuning_args.use_badam: + from badam import BAdamCallback, clip_grad_norm_old_version # type: ignore[import] + from types import MethodType + + trainer.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, trainer.accelerator) + trainer.add_callback(BAdamCallback) + + # Training + if training_args.do_train: + train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint) + trainer.save_model() + if finetuning_args.include_effective_tokens_per_second: + train_result.metrics["effective_tokens_per_sec"] = calculate_tps( + dataset_module["train_dataset"], train_result.metrics, stage="sft" + ) + + trainer.log_metrics("train", train_result.metrics) + trainer.save_metrics("train", train_result.metrics) + trainer.save_state() + if trainer.is_world_process_zero() and finetuning_args.plot_loss: + keys = ["loss"] + if isinstance(dataset_module.get("eval_dataset"), dict): + keys += sum( + [[f"eval_{key}_loss", f"eval_{key}_accuracy"] for key in dataset_module["eval_dataset"].keys()], + [], + ) + else: + keys += ["eval_loss", "eval_accuracy"] + + plot_loss(training_args.output_dir, keys=keys) + + if training_args.predict_with_generate: + tokenizer.padding_side = "left" + + # Evaluation + if training_args.do_eval: + metrics = trainer.evaluate(metric_key_prefix="eval", **gen_kwargs) + trainer.log_metrics("eval", metrics) + trainer.save_metrics("eval", metrics) + + # Predict + if training_args.do_predict: + logger.warning_rank0_once("Batch generation can be very slow. Consider using `scripts/vllm_infer.py` instead.") + predict_results = trainer.predict(dataset_module["eval_dataset"], metric_key_prefix="predict", **gen_kwargs) + trainer.log_metrics("predict", predict_results.metrics) + trainer.save_metrics("predict", predict_results.metrics) + trainer.save_predictions(dataset_module["eval_dataset"], predict_results, generating_args.skip_special_tokens) + + # Create model card + create_modelcard_and_push(trainer, model_args, data_args, training_args, finetuning_args) diff --git a/src/llamafactory/train/tuner.py b/src/llamafactory/train/tuner.py index 38ddb90dc..411ed3ac7 100644 --- a/src/llamafactory/train/tuner.py +++ b/src/llamafactory/train/tuner.py @@ -24,7 +24,12 @@ from ..data import get_template_and_fix_tokenizer from ..extras import logging from ..extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME from ..extras.misc import find_available_port, get_device_name, get_torch_device, infer_optim_dtype -from ..extras.packages import is_mcore_adapter_available, is_ray_available, is_transformers_version_greater_than +from ..extras.packages import ( + is_hyper_parallel_available, + is_mcore_adapter_available, + is_ray_available, + is_transformers_version_greater_than, +) from ..hparams import RayArguments, get_infer_args, get_ray_args, get_train_args, read_args from ..model import load_model, load_tokenizer from .callbacks import LogCallback, PissaConvertCallback, ReporterCallback @@ -71,7 +76,16 @@ def _training_function(config: dict[str, Any]) -> None: callbacks.append(ReporterCallback(model_args, data_args, finetuning_args, generating_args)) # add to last - if finetuning_args.stage in ["pt", "sft", "dpo"] and finetuning_args.use_mca: + if finetuning_args.stage == "sft" and finetuning_args.use_hyper_parallel: + if not is_hyper_parallel_available(): + raise ImportError( + "hyper_parallel is not installed. Please install it with `pip install hyper_parallel`." + ) + from .hyper_parallel import run_sft as run_sft_hp + + run_sft_hp(model_args, data_args, training_args, finetuning_args, generating_args, callbacks) + + elif finetuning_args.stage in ["pt", "sft", "dpo"] and finetuning_args.use_mca: if not is_mcore_adapter_available(): raise ImportError("mcore_adapter is not installed. Please install it with `pip install mcore-adapter`.") if finetuning_args.stage == "pt":