support InternLM

Former-commit-id: a454ef7d57d9c06302d51464cfe39f6d0c48c5a8
This commit is contained in:
hiyouga
2023-07-07 11:02:28 +08:00
parent 601b1747d1
commit 113cdaf1cb
3 changed files with 15 additions and 2 deletions

View File

@@ -104,8 +104,8 @@ class Seq2SeqPeftTrainer(PeftTrainer):
preds = np.where(predict_results.predictions != IGNORE_INDEX, predict_results.predictions, self.tokenizer.pad_token_id)
labels = np.where(predict_results.label_ids != IGNORE_INDEX, predict_results.label_ids, self.tokenizer.pad_token_id)
decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True)
decoded_labels = self.tokenizer.batch_decode(labels, skip_special_tokens=True)
decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True, clean_up_tokenization_spaces=True)
decoded_labels = self.tokenizer.batch_decode(labels, skip_special_tokens=True, clean_up_tokenization_spaces=True)
with open(output_prediction_file, "w", encoding="utf-8") as writer:
res: List[str] = []