[v1] add callbacks (#10255)

This commit is contained in:
jiaqiw09
2026-03-26 19:59:57 +08:00
committed by GitHub
parent 1e536733c6
commit c340aa2a33
5 changed files with 293 additions and 2 deletions

View File

@@ -85,6 +85,10 @@ class TrainingArguments:
default=42,
metadata={"help": "Random seed that will be set at the beginning of training."},
)
logging_steps: int = field(
default=1,
metadata={"help": "Log metrics every N optimizer steps."},
)
def __post_init__(self) -> None:
self.dist_config = get_plugin_config(self.dist_config)

View File

@@ -36,6 +36,12 @@ from ..accelerator.helper import ReduceOp
from ..accelerator.interface import Dim, DistributedInterface
from ..config import TrainingArguments
from ..utils import logging
from ..utils.callbacks import (
CallbackHandler,
LoggingCallback,
TrainerCallback,
TrainerState,
)
from ..utils.helper import compute_valid_tokens
from ..utils.types import BatchInput, HFModel, ModelOutput, Tensor, TorchDataset
from .utils.batching import BatchGenerator
@@ -52,6 +58,7 @@ class BaseTrainer:
model: HFModel,
renderer: Renderer,
train_dataset: TorchDataset,
callbacks: list[TrainerCallback] | None = None,
) -> None:
self.args = args
self.model = model
@@ -99,6 +106,14 @@ class BaseTrainer:
self._init_optimizer()
self._init_lr_scheduler()
# Callbacks
self.callback_handler = CallbackHandler([LoggingCallback()], trainer=self)
for cb in callbacks or []:
self.callback_handler.add_callback(cb)
# Callbacks: TrainerState tracks progress across the full run.
self.state = TrainerState(num_training_steps=self.num_training_steps)
def _create_batch_generator(self) -> None:
self.train_batch_generator = BatchGenerator(
dataset=self.train_dataset,
@@ -174,10 +189,18 @@ class BaseTrainer:
def fit(self) -> None:
"""Train the model."""
self.model.train()
self.callback_handler.on_train_begin(self.args, self.state)
for epoch in range(self.args.num_train_epochs):
self.state.epoch = epoch
self.train_batch_generator.set_epoch(epoch)
self.callback_handler.on_epoch_begin(self.args, self.state)
for micro_batches in self.train_batch_generator:
self.global_step += 1
self.state.global_step = self.global_step
self.callback_handler.on_step_begin(self.args, self.state)
step_loss = 0
step_valid_tokens = compute_valid_tokens(micro_batches)
step_valid_tokens = DistributedInterface().all_reduce(step_valid_tokens, op=ReduceOp.SUM)
@@ -213,14 +236,41 @@ class BaseTrainer:
step_loss, grad_norm = DistributedInterface().all_reduce([step_loss, grad_norm])
DistributedInterface().sync()
if DistributedInterface().get_rank() == 0:
print(f"Epoch {epoch}, Step {self.global_step}, Loss: {step_loss:.4f}, Grad Norm: {grad_norm:.4f}")
# Update state with step metrics
current_lr = (
self.lr_scheduler.get_last_lr()[0]
if hasattr(self.lr_scheduler, "get_last_lr")
else self.args.learning_rate
)
self.state.loss = step_loss
self.state.grad_norm = grad_norm
self.state.learning_rate = current_lr
self.callback_handler.on_step_end(self.args, self.state)
# Logging: trainer decides when to log
if self.global_step % self.args.logging_steps == 0:
logs = {
"epoch": epoch,
"step": self.global_step,
"loss": step_loss,
"grad_norm": grad_norm,
"learning_rate": current_lr,
}
self.callback_handler.on_log(self.args, self.state, logs)
# Check if max_steps is reached
if self.global_step >= self.num_training_steps:
logger.info_rank0(f"Reached max_steps ({self.num_training_steps}), stopping training.")
self.callback_handler.on_epoch_end(self.args, self.state)
self.callback_handler.on_train_end(self.args, self.state)
return
self.callback_handler.on_epoch_end(self.args, self.state)
self.callback_handler.on_train_end(self.args, self.state)
def save_model(self) -> None:
"""Save the model."""
if self.args.dist_config is not None and self.args.dist_config.name in ("deepspeed", "fsdp2"):
@@ -234,3 +284,5 @@ class BaseTrainer:
model_to_save.save_pretrained(self.args.output_dir, max_shard_size="4GB")
self.renderer.processor.save_pretrained(self.args.output_dir, max_shard_size="4GB")
logger.info_rank0(f"Model saved to {self.args.output_dir}")
self.callback_handler.on_save(self.args, self.state)

View File

@@ -0,0 +1,24 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .logging_callback import LoggingCallback
from .trainer_callback import CallbackHandler, TrainerCallback, TrainerState
__all__ = [
"CallbackHandler",
"LoggingCallback",
"TrainerCallback",
"TrainerState",
]

View File

@@ -0,0 +1,64 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import json
import os
from typing import TYPE_CHECKING, Any
from .. import logging
from .trainer_callback import TrainerCallback, TrainerState
if TYPE_CHECKING:
from ...config import TrainingArguments
logger = logging.get_logger(__name__)
class LoggingCallback(TrainerCallback):
"""Logs training metrics to stdout on rank-0 and appends to ``state.log_history``.
On each logging step the entry is also persisted as a JSON line in
``<output_dir>/trainer_log.jsonl`` so that training history survives crashes.
"""
def on_log(
self,
args: TrainingArguments,
state: TrainerState,
logs: dict[str, Any],
**kwargs: Any,
) -> None:
# Persist in history regardless of rank
state.log_history.append(dict(logs))
# Everything below is rank-0 only
from ...accelerator.interface import DistributedInterface # lazy import
if DistributedInterface().get_rank() != 0:
return
# Human-readable output to stdout
display_logs = {**logs, "total_steps": state.num_training_steps}
parts = ", ".join(f"{k}: {v:.4f}" if isinstance(v, float) else f"{k}: {v}" for k, v in display_logs.items())
logger.info_rank0(parts)
# Append to JSONL log file in output_dir
os.makedirs(args.output_dir, exist_ok=True)
log_file = os.path.join(args.output_dir, "trainer_log.jsonl")
with open(log_file, "a", encoding="utf-8") as f:
f.write(json.dumps(display_logs, ensure_ascii=False) + "\n")

View File

@@ -0,0 +1,147 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from ...config import TrainingArguments
@dataclass
class TrainerState:
"""A read-only snapshot of training progress passed to every callback hook.
Attributes:
epoch: Current epoch (0-indexed).
global_step: Number of optimizer steps completed so far.
num_training_steps: Total number of optimizer steps planned.
loss: Scalar loss value of the most recent step.
grad_norm: Gradient-norm value of the most recent step.
learning_rate: Current learning rate seen by the optimizer.
log_history: List of per-step log dicts emitted by ``LoggingCallback``.
"""
epoch: int = 0
global_step: int = 0
num_training_steps: int = 0
loss: float = 0.0
grad_norm: float = 0.0
learning_rate: float = 0.0
log_history: list[dict[str, Any]] = field(default_factory=list)
class TrainerCallback:
"""Abstract base class for training callbacks.
Subclass and override whichever hooks you need. All hooks receive:
- ``args`` the :class:`~llamafactory.v1.config.TrainingArguments`.
- ``state`` a :class:`TrainerState` snapshot (read-only).
- ``**kwargs`` extra keyword arguments (model, optimizer, …).
Callbacks are *observers*: they should NOT mutate training flow.
Hook call order::
on_train_begin
for each epoch:
on_epoch_begin
for each step:
on_step_begin
(forward / backward / optimizer.step)
on_step_end
[on_log] ← if this step is a logging step
on_epoch_end
on_train_end
"""
def on_train_begin(self, args: TrainingArguments, state: TrainerState, **kwargs: Any) -> None:
"""Called once before the first training step."""
def on_train_end(self, args: TrainingArguments, state: TrainerState, **kwargs: Any) -> None:
"""Called once after the last training step."""
def on_epoch_begin(self, args: TrainingArguments, state: TrainerState, **kwargs: Any) -> None:
"""Called at the beginning of each epoch."""
def on_epoch_end(self, args: TrainingArguments, state: TrainerState, **kwargs: Any) -> None:
"""Called at the end of each epoch."""
def on_step_begin(self, args: TrainingArguments, state: TrainerState, **kwargs: Any) -> None:
"""Called before the forward/backward pass of each optimizer step."""
def on_step_end(self, args: TrainingArguments, state: TrainerState, **kwargs: Any) -> None:
"""Called after the optimizer step."""
def on_log(self, args: TrainingArguments, state: TrainerState, logs: dict[str, Any], **kwargs: Any) -> None:
"""Called when the trainer emits a log entry."""
def on_save(self, args: TrainingArguments, state: TrainerState, **kwargs: Any) -> None:
"""Called after the model checkpoint has been written to disk."""
class CallbackHandler:
"""Owns a list of :class:`TrainerCallback` instances and fans out hook calls.
Usage::
handler = CallbackHandler([LoggingCallback(), MyWandbCallback()], trainer=trainer)
handler.on_train_begin(args, state)
"""
def __init__(self, callbacks: list[TrainerCallback] | None = None, trainer: Any = None) -> None:
self.callbacks: list[TrainerCallback] = list(callbacks or [])
self.trainer = trainer
def add_callback(self, callback: TrainerCallback) -> None:
"""Append a callback to the handler."""
self.callbacks.append(callback)
def _call(self, event: str, args: TrainingArguments, state: TrainerState, **kwargs: Any) -> None:
if self.trainer is not None:
kwargs.setdefault("model", getattr(self.trainer, "model", None))
kwargs.setdefault("optimizer", getattr(self.trainer, "optimizer", None))
kwargs.setdefault("lr_scheduler", getattr(self.trainer, "lr_scheduler", None))
kwargs.setdefault("train_dataloader", getattr(self.trainer, "train_batch_generator", None))
for cb in self.callbacks:
getattr(cb, event)(args, state, **kwargs)
def on_train_begin(self, args: TrainingArguments, state: TrainerState) -> None:
self._call("on_train_begin", args, state)
def on_train_end(self, args: TrainingArguments, state: TrainerState) -> None:
self._call("on_train_end", args, state)
def on_epoch_begin(self, args: TrainingArguments, state: TrainerState) -> None:
self._call("on_epoch_begin", args, state)
def on_epoch_end(self, args: TrainingArguments, state: TrainerState) -> None:
self._call("on_epoch_end", args, state)
def on_step_begin(self, args: TrainingArguments, state: TrainerState) -> None:
self._call("on_step_begin", args, state)
def on_step_end(self, args: TrainingArguments, state: TrainerState) -> None:
self._call("on_step_end", args, state)
def on_log(self, args: TrainingArguments, state: TrainerState, logs: dict[str, Any]) -> None:
self._call("on_log", args, state, logs=logs)
def on_save(self, args: TrainingArguments, state: TrainerState) -> None:
self._call("on_save", args, state)