alter rewards data type

Former-commit-id: 3eb7eb2d37525da50fe401ab7c59532e6e1ef984
This commit is contained in:
hiyouga
2023-06-02 14:19:51 +08:00
parent 896dbfec16
commit e9ab06678f
12 changed files with 40 additions and 50 deletions

View File

@@ -109,7 +109,8 @@ class PeftTrainer(Seq2SeqTrainer):
if hasattr(model, "v_head"): # save valuehead weights
torch.save(get_state_dict(getattr(model, "v_head")), os.path.join(output_dir, VALUE_HEAD_FILE_NAME))
torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
with open(os.path.join(output_dir, TRAINING_ARGS_NAME), "w", encoding="utf-8") as f:
f.write(self.args.to_json_string() + "\n")
self.finetuning_args.save_to_json(os.path.join(output_dir, FINETUNING_ARGS_NAME))
def _load_best_model(self):