Former-commit-id: b4790c66c126567bd193de52a564e3ce11c94769
This commit is contained in:
fzc8578
2025-01-06 19:32:39 +08:00
parent 08729dbefc
commit 8c2a712247
4 changed files with 15 additions and 7 deletions

View File

@@ -24,6 +24,7 @@ import numpy as np
import torch
from transformers import Seq2SeqTrainer
from typing_extensions import override
import copy
from ...extras import logging
from ...extras.constants import IGNORE_INDEX
@@ -122,7 +123,6 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
labels = inputs.pop("labels", None)
else:
labels = inputs.get("labels")
loss, generated_tokens, _ = super().prediction_step(
model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys, **gen_kwargs
)