mirror of
https://github.com/hiyouga/LlamaFactory.git
synced 2026-03-26 20:17:45 +08:00
[v1] add callbacks (#10255)
This commit is contained in:
@@ -85,6 +85,10 @@ class TrainingArguments:
|
|||||||
default=42,
|
default=42,
|
||||||
metadata={"help": "Random seed that will be set at the beginning of training."},
|
metadata={"help": "Random seed that will be set at the beginning of training."},
|
||||||
)
|
)
|
||||||
|
logging_steps: int = field(
|
||||||
|
default=1,
|
||||||
|
metadata={"help": "Log metrics every N optimizer steps."},
|
||||||
|
)
|
||||||
|
|
||||||
def __post_init__(self) -> None:
|
def __post_init__(self) -> None:
|
||||||
self.dist_config = get_plugin_config(self.dist_config)
|
self.dist_config = get_plugin_config(self.dist_config)
|
||||||
|
|||||||
@@ -36,6 +36,12 @@ from ..accelerator.helper import ReduceOp
|
|||||||
from ..accelerator.interface import Dim, DistributedInterface
|
from ..accelerator.interface import Dim, DistributedInterface
|
||||||
from ..config import TrainingArguments
|
from ..config import TrainingArguments
|
||||||
from ..utils import logging
|
from ..utils import logging
|
||||||
|
from ..utils.callbacks import (
|
||||||
|
CallbackHandler,
|
||||||
|
LoggingCallback,
|
||||||
|
TrainerCallback,
|
||||||
|
TrainerState,
|
||||||
|
)
|
||||||
from ..utils.helper import compute_valid_tokens
|
from ..utils.helper import compute_valid_tokens
|
||||||
from ..utils.types import BatchInput, HFModel, ModelOutput, Tensor, TorchDataset
|
from ..utils.types import BatchInput, HFModel, ModelOutput, Tensor, TorchDataset
|
||||||
from .utils.batching import BatchGenerator
|
from .utils.batching import BatchGenerator
|
||||||
@@ -52,6 +58,7 @@ class BaseTrainer:
|
|||||||
model: HFModel,
|
model: HFModel,
|
||||||
renderer: Renderer,
|
renderer: Renderer,
|
||||||
train_dataset: TorchDataset,
|
train_dataset: TorchDataset,
|
||||||
|
callbacks: list[TrainerCallback] | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.args = args
|
self.args = args
|
||||||
self.model = model
|
self.model = model
|
||||||
@@ -99,6 +106,14 @@ class BaseTrainer:
|
|||||||
self._init_optimizer()
|
self._init_optimizer()
|
||||||
self._init_lr_scheduler()
|
self._init_lr_scheduler()
|
||||||
|
|
||||||
|
# Callbacks
|
||||||
|
self.callback_handler = CallbackHandler([LoggingCallback()], trainer=self)
|
||||||
|
for cb in callbacks or []:
|
||||||
|
self.callback_handler.add_callback(cb)
|
||||||
|
|
||||||
|
# Callbacks: TrainerState tracks progress across the full run.
|
||||||
|
self.state = TrainerState(num_training_steps=self.num_training_steps)
|
||||||
|
|
||||||
def _create_batch_generator(self) -> None:
|
def _create_batch_generator(self) -> None:
|
||||||
self.train_batch_generator = BatchGenerator(
|
self.train_batch_generator = BatchGenerator(
|
||||||
dataset=self.train_dataset,
|
dataset=self.train_dataset,
|
||||||
@@ -174,10 +189,18 @@ class BaseTrainer:
|
|||||||
def fit(self) -> None:
|
def fit(self) -> None:
|
||||||
"""Train the model."""
|
"""Train the model."""
|
||||||
self.model.train()
|
self.model.train()
|
||||||
|
self.callback_handler.on_train_begin(self.args, self.state)
|
||||||
for epoch in range(self.args.num_train_epochs):
|
for epoch in range(self.args.num_train_epochs):
|
||||||
|
self.state.epoch = epoch
|
||||||
self.train_batch_generator.set_epoch(epoch)
|
self.train_batch_generator.set_epoch(epoch)
|
||||||
|
self.callback_handler.on_epoch_begin(self.args, self.state)
|
||||||
|
|
||||||
for micro_batches in self.train_batch_generator:
|
for micro_batches in self.train_batch_generator:
|
||||||
self.global_step += 1
|
self.global_step += 1
|
||||||
|
|
||||||
|
self.state.global_step = self.global_step
|
||||||
|
self.callback_handler.on_step_begin(self.args, self.state)
|
||||||
|
|
||||||
step_loss = 0
|
step_loss = 0
|
||||||
step_valid_tokens = compute_valid_tokens(micro_batches)
|
step_valid_tokens = compute_valid_tokens(micro_batches)
|
||||||
step_valid_tokens = DistributedInterface().all_reduce(step_valid_tokens, op=ReduceOp.SUM)
|
step_valid_tokens = DistributedInterface().all_reduce(step_valid_tokens, op=ReduceOp.SUM)
|
||||||
@@ -213,14 +236,41 @@ class BaseTrainer:
|
|||||||
|
|
||||||
step_loss, grad_norm = DistributedInterface().all_reduce([step_loss, grad_norm])
|
step_loss, grad_norm = DistributedInterface().all_reduce([step_loss, grad_norm])
|
||||||
DistributedInterface().sync()
|
DistributedInterface().sync()
|
||||||
if DistributedInterface().get_rank() == 0:
|
|
||||||
print(f"Epoch {epoch}, Step {self.global_step}, Loss: {step_loss:.4f}, Grad Norm: {grad_norm:.4f}")
|
# Update state with step metrics
|
||||||
|
current_lr = (
|
||||||
|
self.lr_scheduler.get_last_lr()[0]
|
||||||
|
if hasattr(self.lr_scheduler, "get_last_lr")
|
||||||
|
else self.args.learning_rate
|
||||||
|
)
|
||||||
|
self.state.loss = step_loss
|
||||||
|
self.state.grad_norm = grad_norm
|
||||||
|
self.state.learning_rate = current_lr
|
||||||
|
|
||||||
|
self.callback_handler.on_step_end(self.args, self.state)
|
||||||
|
|
||||||
|
# Logging: trainer decides when to log
|
||||||
|
if self.global_step % self.args.logging_steps == 0:
|
||||||
|
logs = {
|
||||||
|
"epoch": epoch,
|
||||||
|
"step": self.global_step,
|
||||||
|
"loss": step_loss,
|
||||||
|
"grad_norm": grad_norm,
|
||||||
|
"learning_rate": current_lr,
|
||||||
|
}
|
||||||
|
self.callback_handler.on_log(self.args, self.state, logs)
|
||||||
|
|
||||||
# Check if max_steps is reached
|
# Check if max_steps is reached
|
||||||
if self.global_step >= self.num_training_steps:
|
if self.global_step >= self.num_training_steps:
|
||||||
logger.info_rank0(f"Reached max_steps ({self.num_training_steps}), stopping training.")
|
logger.info_rank0(f"Reached max_steps ({self.num_training_steps}), stopping training.")
|
||||||
|
self.callback_handler.on_epoch_end(self.args, self.state)
|
||||||
|
self.callback_handler.on_train_end(self.args, self.state)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
self.callback_handler.on_epoch_end(self.args, self.state)
|
||||||
|
|
||||||
|
self.callback_handler.on_train_end(self.args, self.state)
|
||||||
|
|
||||||
def save_model(self) -> None:
|
def save_model(self) -> None:
|
||||||
"""Save the model."""
|
"""Save the model."""
|
||||||
if self.args.dist_config is not None and self.args.dist_config.name in ("deepspeed", "fsdp2"):
|
if self.args.dist_config is not None and self.args.dist_config.name in ("deepspeed", "fsdp2"):
|
||||||
@@ -234,3 +284,5 @@ class BaseTrainer:
|
|||||||
model_to_save.save_pretrained(self.args.output_dir, max_shard_size="4GB")
|
model_to_save.save_pretrained(self.args.output_dir, max_shard_size="4GB")
|
||||||
self.renderer.processor.save_pretrained(self.args.output_dir, max_shard_size="4GB")
|
self.renderer.processor.save_pretrained(self.args.output_dir, max_shard_size="4GB")
|
||||||
logger.info_rank0(f"Model saved to {self.args.output_dir}")
|
logger.info_rank0(f"Model saved to {self.args.output_dir}")
|
||||||
|
|
||||||
|
self.callback_handler.on_save(self.args, self.state)
|
||||||
|
|||||||
24
src/llamafactory/v1/utils/callbacks/__init__.py
Normal file
24
src/llamafactory/v1/utils/callbacks/__init__.py
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
# Copyright 2025 the LlamaFactory team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from .logging_callback import LoggingCallback
|
||||||
|
from .trainer_callback import CallbackHandler, TrainerCallback, TrainerState
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"CallbackHandler",
|
||||||
|
"LoggingCallback",
|
||||||
|
"TrainerCallback",
|
||||||
|
"TrainerState",
|
||||||
|
]
|
||||||
64
src/llamafactory/v1/utils/callbacks/logging_callback.py
Normal file
64
src/llamafactory/v1/utils/callbacks/logging_callback.py
Normal file
@@ -0,0 +1,64 @@
|
|||||||
|
# Copyright 2025 the LlamaFactory team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
|
from .. import logging
|
||||||
|
from .trainer_callback import TrainerCallback, TrainerState
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from ...config import TrainingArguments
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class LoggingCallback(TrainerCallback):
    """Records training metrics: rank-0 stdout output plus ``state.log_history``.

    Every log entry is additionally appended as one JSON line to
    ``<output_dir>/trainer_log.jsonl``, so the training history survives crashes.
    """

    def on_log(
        self,
        args: TrainingArguments,
        state: TrainerState,
        logs: dict[str, Any],
        **kwargs: Any,
    ) -> None:
        # Every rank keeps a copy of the entry in the shared state.
        state.log_history.append(dict(logs))

        from ...accelerator.interface import DistributedInterface  # lazy import

        # Console and file output happen on rank-0 only.
        if DistributedInterface().get_rank() != 0:
            return

        display_logs = {**logs, "total_steps": state.num_training_steps}

        # Human-readable line on stdout: floats rendered to 4 decimal places.
        rendered = []
        for key, value in display_logs.items():
            if isinstance(value, float):
                rendered.append(f"{key}: {value:.4f}")
            else:
                rendered.append(f"{key}: {value}")
        logger.info_rank0(", ".join(rendered))

        # Durable JSONL record under the output directory.
        log_file = os.path.join(args.output_dir, "trainer_log.jsonl")
        os.makedirs(args.output_dir, exist_ok=True)
        with open(log_file, "a", encoding="utf-8") as f:
            f.write(json.dumps(display_logs, ensure_ascii=False) + "\n")
|
||||||
147
src/llamafactory/v1/utils/callbacks/trainer_callback.py
Normal file
147
src/llamafactory/v1/utils/callbacks/trainer_callback.py
Normal file
@@ -0,0 +1,147 @@
|
|||||||
|
# Copyright 2025 the LlamaFactory team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from ...config import TrainingArguments
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class TrainerState:
    """Tracks training progress across the full run; passed to every callback hook.

    The trainer mutates this object in place (step/epoch counters and the latest
    step metrics), and ``LoggingCallback`` appends entries to ``log_history`` —
    so it is a live, shared record rather than a read-only snapshot. Callbacks
    other than loggers should still treat it as read-only.

    Attributes:
        epoch: Current epoch (0-indexed).
        global_step: Number of optimizer steps completed so far.
        num_training_steps: Total number of optimizer steps planned.
        loss: Scalar loss value of the most recent step.
        grad_norm: Gradient-norm value of the most recent step.
        learning_rate: Current learning rate seen by the optimizer.
        log_history: List of per-step log dicts emitted by ``LoggingCallback``.
    """

    epoch: int = 0
    global_step: int = 0
    num_training_steps: int = 0
    loss: float = 0.0
    grad_norm: float = 0.0
    learning_rate: float = 0.0
    log_history: list[dict[str, Any]] = field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
class TrainerCallback:
    """Base class for observing the training loop.

    Override any subset of the hooks below. Each hook receives:

    - ``args`` – the :class:`~llamafactory.v1.config.TrainingArguments`.
    - ``state`` – a :class:`TrainerState` snapshot (read-only).
    - ``**kwargs`` – extra keyword arguments (model, optimizer, …).

    Callbacks are *observers*: they should NOT mutate training flow.

    The hooks fire in this order::

        on_train_begin
        for each epoch:
            on_epoch_begin
            for each step:
                on_step_begin
                (forward / backward / optimizer.step)
                on_step_end
                [on_log]  <- if this step is a logging step
            on_epoch_end
        on_train_end
    """

    def on_train_begin(self, args: TrainingArguments, state: TrainerState, **kwargs: Any) -> None:
        """Fires once, before the first training step."""

    def on_train_end(self, args: TrainingArguments, state: TrainerState, **kwargs: Any) -> None:
        """Fires once, after the last training step."""

    def on_epoch_begin(self, args: TrainingArguments, state: TrainerState, **kwargs: Any) -> None:
        """Fires at the start of every epoch."""

    def on_epoch_end(self, args: TrainingArguments, state: TrainerState, **kwargs: Any) -> None:
        """Fires at the end of every epoch."""

    def on_step_begin(self, args: TrainingArguments, state: TrainerState, **kwargs: Any) -> None:
        """Fires before the forward/backward pass of each optimizer step."""

    def on_step_end(self, args: TrainingArguments, state: TrainerState, **kwargs: Any) -> None:
        """Fires after each optimizer step completes."""

    def on_log(self, args: TrainingArguments, state: TrainerState, logs: dict[str, Any], **kwargs: Any) -> None:
        """Fires whenever the trainer emits a log entry."""

    def on_save(self, args: TrainingArguments, state: TrainerState, **kwargs: Any) -> None:
        """Fires after a model checkpoint has been written to disk."""
|
||||||
|
|
||||||
|
|
||||||
|
class CallbackHandler:
    """Dispatches lifecycle events to a registered list of :class:`TrainerCallback`s.

    Usage::

        handler = CallbackHandler([LoggingCallback(), MyWandbCallback()], trainer=trainer)
        handler.on_train_begin(args, state)
    """

    def __init__(self, callbacks: list[TrainerCallback] | None = None, trainer: Any = None) -> None:
        # Copy so later mutation of the caller's list does not affect the handler.
        self.callbacks: list[TrainerCallback] = list(callbacks or [])
        self.trainer = trainer

    def add_callback(self, callback: TrainerCallback) -> None:
        """Register an additional callback."""
        self.callbacks.append(callback)

    def _call(self, event: str, args: TrainingArguments, state: TrainerState, **kwargs: Any) -> None:
        # Enrich kwargs with trainer components, never overriding explicit values.
        if self.trainer is not None:
            components = {
                "model": getattr(self.trainer, "model", None),
                "optimizer": getattr(self.trainer, "optimizer", None),
                "lr_scheduler": getattr(self.trainer, "lr_scheduler", None),
                "train_dataloader": getattr(self.trainer, "train_batch_generator", None),
            }
            for name, component in components.items():
                kwargs.setdefault(name, component)

        # Fan the event out to every registered callback, in registration order.
        for callback in self.callbacks:
            getattr(callback, event)(args, state, **kwargs)

    def on_train_begin(self, args: TrainingArguments, state: TrainerState) -> None:
        self._call("on_train_begin", args, state)

    def on_train_end(self, args: TrainingArguments, state: TrainerState) -> None:
        self._call("on_train_end", args, state)

    def on_epoch_begin(self, args: TrainingArguments, state: TrainerState) -> None:
        self._call("on_epoch_begin", args, state)

    def on_epoch_end(self, args: TrainingArguments, state: TrainerState) -> None:
        self._call("on_epoch_end", args, state)

    def on_step_begin(self, args: TrainingArguments, state: TrainerState) -> None:
        self._call("on_step_begin", args, state)

    def on_step_end(self, args: TrainingArguments, state: TrainerState) -> None:
        self._call("on_step_end", args, state)

    def on_log(self, args: TrainingArguments, state: TrainerState, logs: dict[str, Any]) -> None:
        self._call("on_log", args, state, logs=logs)

    def on_save(self, args: TrainingArguments, state: TrainerState) -> None:
        self._call("on_save", args, state)
|
||||||
Reference in New Issue
Block a user