mirror of
https://github.com/hiyouga/LlamaFactory.git
synced 2026-03-25 19:47:44 +08:00
[algo] add ASFT (#10174)
This commit is contained in:
@@ -24,7 +24,7 @@ from ...extras.misc import calculate_tps
|
||||
from ...extras.packages import is_transformers_version_greater_than
|
||||
from ...extras.ploting import plot_loss
|
||||
from ...model import load_model, load_tokenizer
|
||||
from ..trainer_utils import create_modelcard_and_push
|
||||
from ..trainer_utils import create_modelcard_and_push, create_ref_model
|
||||
from .metric import ComputeAccuracy, ComputeSimilarity, eval_logit_processor
|
||||
from .trainer import CustomSeq2SeqTrainer
|
||||
|
||||
@@ -52,6 +52,10 @@ def run_sft(
|
||||
dataset_module = get_dataset(template, model_args, data_args, training_args, stage="sft", **tokenizer_module)
|
||||
model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
|
||||
|
||||
ref_model = None
|
||||
if finetuning_args.use_asft_loss:
|
||||
ref_model = create_ref_model(model_args, finetuning_args)
|
||||
|
||||
if getattr(model, "is_quantized", False) and not training_args.do_train:
|
||||
setattr(model, "_hf_peft_config_loaded", True) # hack here: make model compatible with prediction
|
||||
|
||||
@@ -124,6 +128,7 @@ def run_sft(
|
||||
data_collator=data_collator,
|
||||
callbacks=callbacks,
|
||||
gen_kwargs=gen_kwargs,
|
||||
ref_model=ref_model,
|
||||
**dataset_module,
|
||||
**tokenizer_module,
|
||||
**metric_module,
|
||||
|
||||
Reference in New Issue
Block a user