[algo] add ASFT (#10174)

This commit is contained in:
Junyou Su
2026-02-12 13:12:14 +08:00
committed by GitHub
parent ab073f4c13
commit 675ce8cc7f
6 changed files with 228 additions and 2 deletions

View File

@@ -24,7 +24,7 @@ from ...extras.misc import calculate_tps
from ...extras.packages import is_transformers_version_greater_than
from ...extras.ploting import plot_loss
from ...model import load_model, load_tokenizer
from ..trainer_utils import create_modelcard_and_push
from ..trainer_utils import create_modelcard_and_push, create_ref_model
from .metric import ComputeAccuracy, ComputeSimilarity, eval_logit_processor
from .trainer import CustomSeq2SeqTrainer
@@ -52,6 +52,10 @@ def run_sft(
dataset_module = get_dataset(template, model_args, data_args, training_args, stage="sft", **tokenizer_module)
model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
ref_model = None
if finetuning_args.use_asft_loss:
ref_model = create_ref_model(model_args, finetuning_args)
if getattr(model, "is_quantized", False) and not training_args.do_train:
setattr(model, "_hf_peft_config_loaded", True) # hack here: make model compatible with prediction
@@ -124,6 +128,7 @@ def run_sft(
data_collator=data_collator,
callbacks=callbacks,
gen_kwargs=gen_kwargs,
ref_model=ref_model,
**dataset_module,
**tokenizer_module,
**metric_module,