mirror of
https://github.com/hiyouga/LlamaFactory.git
synced 2026-03-22 18:03:23 +08:00
support InternLM
Former-commit-id: a454ef7d57d9c06302d51464cfe39f6d0c48c5a8
This commit is contained in:
@@ -104,8 +104,8 @@ class Seq2SeqPeftTrainer(PeftTrainer):
|
||||
preds = np.where(predict_results.predictions != IGNORE_INDEX, predict_results.predictions, self.tokenizer.pad_token_id)
|
||||
labels = np.where(predict_results.label_ids != IGNORE_INDEX, predict_results.label_ids, self.tokenizer.pad_token_id)
|
||||
|
||||
decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True)
|
||||
decoded_labels = self.tokenizer.batch_decode(labels, skip_special_tokens=True)
|
||||
decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True, clean_up_tokenization_spaces=True)
|
||||
decoded_labels = self.tokenizer.batch_decode(labels, skip_special_tokens=True, clean_up_tokenization_spaces=True)
|
||||
|
||||
with open(output_prediction_file, "w", encoding="utf-8") as writer:
|
||||
res: List[str] = []
|
||||
|
||||
Reference in New Issue
Block a user