mirror of
https://github.com/hiyouga/LlamaFactory.git
synced 2026-03-23 18:53:23 +08:00
alter rewards data type
Former-commit-id: 3eb7eb2d37525da50fe401ab7c59532e6e1ef984
This commit is contained in:
@@ -109,7 +109,8 @@ class PeftTrainer(Seq2SeqTrainer):
|
||||
if hasattr(model, "v_head"): # save valuehead weights
|
||||
torch.save(get_state_dict(getattr(model, "v_head")), os.path.join(output_dir, VALUE_HEAD_FILE_NAME))
|
||||
|
||||
torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
|
||||
with open(os.path.join(output_dir, TRAINING_ARGS_NAME), "w", encoding="utf-8") as f:
|
||||
f.write(self.args.to_json_string() + "\n")
|
||||
self.finetuning_args.save_to_json(os.path.join(output_dir, FINETUNING_ARGS_NAME))
|
||||
|
||||
def _load_best_model(self):
|
||||
|
||||
Reference in New Issue
Block a user