mirror of
https://github.com/hiyouga/LlamaFactory.git
synced 2026-03-23 18:53:23 +08:00
refactor evaluation, upgrade trl to 074
Former-commit-id: ed09ebe2c1926ffdb0520b3866f7fd03a9aed046
This commit is contained in:
@@ -226,7 +226,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
||||
replace_model(unwrapped_model, target="default")
|
||||
return rewards
|
||||
|
||||
@PPODecorators.empty_cuda_cache()
|
||||
@PPODecorators.empty_device_cache()
|
||||
def batched_forward_pass(
|
||||
self,
|
||||
model: "AutoModelForCausalLMWithValueHead",
|
||||
|
||||
@@ -42,7 +42,7 @@ def run_ppo(
|
||||
ppo_epochs=1,
|
||||
max_grad_norm=training_args.max_grad_norm,
|
||||
seed=training_args.seed,
|
||||
optimize_cuda_cache=True,
|
||||
optimize_device_cache=True,
|
||||
target=finetuning_args.ppo_target,
|
||||
log_with=finetuning_args.ppo_logger,
|
||||
use_score_scaling=finetuning_args.ppo_score_norm,
|
||||
|
||||
Reference in New Issue
Block a user