fix generating args

Former-commit-id: 52805a8441bd7b324bd89489de60f18f103c8e4c
This commit is contained in:
hiyouga
2023-06-13 01:33:56 +08:00
parent 4724ae3492
commit 6828f07d54
5 changed files with 20 additions and 16 deletions

View File

@@ -30,8 +30,8 @@ def main():
# Override the decoding parameters of Seq2SeqTrainer
training_args.generation_max_length = training_args.generation_max_length if \
training_args.generation_max_length is not None else data_args.max_target_length
-    training_args.generation_num_beams = data_args.num_beams if \
-        data_args.num_beams is not None else training_args.generation_num_beams
+    training_args.generation_num_beams = data_args.eval_num_beams if \
+        data_args.eval_num_beams is not None else training_args.generation_num_beams
# Split the dataset
if training_args.do_train: