refactor evaluation, upgrade trl to 074

Former-commit-id: ed09ebe2c1926ffdb0520b3866f7fd03a9aed046
This commit is contained in:
hiyouga
2023-11-13 22:20:35 +08:00
parent 989eccd286
commit 64fc9ba678
21 changed files with 341 additions and 247 deletions

View File

@@ -226,7 +226,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
replace_model(unwrapped_model, target="default")
return rewards
@PPODecorators.empty_cuda_cache()
@PPODecorators.empty_device_cache()
def batched_forward_pass(
self,
model: "AutoModelForCausalLMWithValueHead",

View File

@@ -42,7 +42,7 @@ def run_ppo(
ppo_epochs=1,
max_grad_norm=training_args.max_grad_norm,
seed=training_args.seed,
optimize_cuda_cache=True,
optimize_device_cache=True,
target=finetuning_args.ppo_target,
log_with=finetuning_args.ppo_logger,
use_score_scaling=finetuning_args.ppo_score_norm,