[v1] add callbacks (#10255)

This commit is contained in:
jiaqiw09
2026-03-26 19:59:57 +08:00
committed by GitHub
parent 1e536733c6
commit c340aa2a33
5 changed files with 293 additions and 2 deletions

View File

@@ -85,6 +85,10 @@ class TrainingArguments:
default=42,
metadata={"help": "Random seed that will be set at the beginning of training."},
)
logging_steps: int = field(
default=1,
metadata={"help": "Log metrics every N optimizer steps."},
)
def __post_init__(self) -> None:
self.dist_config = get_plugin_config(self.dist_config)

View File

@@ -36,6 +36,12 @@ from ..accelerator.helper import ReduceOp
from ..accelerator.interface import Dim, DistributedInterface
from ..config import TrainingArguments
from ..utils import logging
from ..utils.callbacks import (
CallbackHandler,
LoggingCallback,
TrainerCallback,
TrainerState,
)
from ..utils.helper import compute_valid_tokens from ..utils.helper import compute_valid_tokens
from ..utils.types import BatchInput, HFModel, ModelOutput, Tensor, TorchDataset from ..utils.types import BatchInput, HFModel, ModelOutput, Tensor, TorchDataset
from .utils.batching import BatchGenerator from .utils.batching import BatchGenerator
@@ -52,6 +58,7 @@ class BaseTrainer:
model: HFModel, model: HFModel,
renderer: Renderer, renderer: Renderer,
train_dataset: TorchDataset, train_dataset: TorchDataset,
callbacks: list[TrainerCallback] | None = None,
) -> None: ) -> None:
self.args = args self.args = args
self.model = model self.model = model
@@ -99,6 +106,14 @@ class BaseTrainer:
self._init_optimizer() self._init_optimizer()
self._init_lr_scheduler() self._init_lr_scheduler()
# Callbacks
self.callback_handler = CallbackHandler([LoggingCallback()], trainer=self)
for cb in callbacks or []:
self.callback_handler.add_callback(cb)
# Callbacks: TrainerState tracks progress across the full run.
self.state = TrainerState(num_training_steps=self.num_training_steps)
def _create_batch_generator(self) -> None: def _create_batch_generator(self) -> None:
self.train_batch_generator = BatchGenerator( self.train_batch_generator = BatchGenerator(
dataset=self.train_dataset, dataset=self.train_dataset,
@@ -174,10 +189,18 @@ class BaseTrainer:
def fit(self) -> None: def fit(self) -> None:
"""Train the model.""" """Train the model."""
self.model.train() self.model.train()
self.callback_handler.on_train_begin(self.args, self.state)
for epoch in range(self.args.num_train_epochs): for epoch in range(self.args.num_train_epochs):
self.state.epoch = epoch
self.train_batch_generator.set_epoch(epoch) self.train_batch_generator.set_epoch(epoch)
self.callback_handler.on_epoch_begin(self.args, self.state)
for micro_batches in self.train_batch_generator: for micro_batches in self.train_batch_generator:
self.global_step += 1 self.global_step += 1
self.state.global_step = self.global_step
self.callback_handler.on_step_begin(self.args, self.state)
step_loss = 0 step_loss = 0
step_valid_tokens = compute_valid_tokens(micro_batches) step_valid_tokens = compute_valid_tokens(micro_batches)
step_valid_tokens = DistributedInterface().all_reduce(step_valid_tokens, op=ReduceOp.SUM) step_valid_tokens = DistributedInterface().all_reduce(step_valid_tokens, op=ReduceOp.SUM)
@@ -213,14 +236,41 @@ class BaseTrainer:
step_loss, grad_norm = DistributedInterface().all_reduce([step_loss, grad_norm]) step_loss, grad_norm = DistributedInterface().all_reduce([step_loss, grad_norm])
DistributedInterface().sync() DistributedInterface().sync()
if DistributedInterface().get_rank() == 0:
print(f"Epoch {epoch}, Step {self.global_step}, Loss: {step_loss:.4f}, Grad Norm: {grad_norm:.4f}") # Update state with step metrics
current_lr = (
self.lr_scheduler.get_last_lr()[0]
if hasattr(self.lr_scheduler, "get_last_lr")
else self.args.learning_rate
)
self.state.loss = step_loss
self.state.grad_norm = grad_norm
self.state.learning_rate = current_lr
self.callback_handler.on_step_end(self.args, self.state)
# Logging: trainer decides when to log
if self.global_step % self.args.logging_steps == 0:
logs = {
"epoch": epoch,
"step": self.global_step,
"loss": step_loss,
"grad_norm": grad_norm,
"learning_rate": current_lr,
}
self.callback_handler.on_log(self.args, self.state, logs)
# Check if max_steps is reached # Check if max_steps is reached
if self.global_step >= self.num_training_steps: if self.global_step >= self.num_training_steps:
logger.info_rank0(f"Reached max_steps ({self.num_training_steps}), stopping training.") logger.info_rank0(f"Reached max_steps ({self.num_training_steps}), stopping training.")
self.callback_handler.on_epoch_end(self.args, self.state)
self.callback_handler.on_train_end(self.args, self.state)
return return
self.callback_handler.on_epoch_end(self.args, self.state)
self.callback_handler.on_train_end(self.args, self.state)
def save_model(self) -> None: def save_model(self) -> None:
"""Save the model.""" """Save the model."""
if self.args.dist_config is not None and self.args.dist_config.name in ("deepspeed", "fsdp2"): if self.args.dist_config is not None and self.args.dist_config.name in ("deepspeed", "fsdp2"):
@@ -234,3 +284,5 @@ class BaseTrainer:
model_to_save.save_pretrained(self.args.output_dir, max_shard_size="4GB") model_to_save.save_pretrained(self.args.output_dir, max_shard_size="4GB")
self.renderer.processor.save_pretrained(self.args.output_dir, max_shard_size="4GB") self.renderer.processor.save_pretrained(self.args.output_dir, max_shard_size="4GB")
logger.info_rank0(f"Model saved to {self.args.output_dir}") logger.info_rank0(f"Model saved to {self.args.output_dir}")
self.callback_handler.on_save(self.args, self.state)

View File

@@ -0,0 +1,24 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .logging_callback import LoggingCallback
from .trainer_callback import CallbackHandler, TrainerCallback, TrainerState
__all__ = [
"CallbackHandler",
"LoggingCallback",
"TrainerCallback",
"TrainerState",
]

View File

@@ -0,0 +1,64 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import json
import os
from typing import TYPE_CHECKING, Any
from .. import logging
from .trainer_callback import TrainerCallback, TrainerState
if TYPE_CHECKING:
from ...config import TrainingArguments
logger = logging.get_logger(__name__)
class LoggingCallback(TrainerCallback):
    """Default callback that records training metrics.

    Every emitted log entry is appended to ``state.log_history`` on all ranks.
    On the rank-0 process the entry is additionally printed through the logger
    and appended as one JSON line to ``<output_dir>/trainer_log.jsonl`` so that
    training history survives crashes.
    """

    def on_log(
        self,
        args: TrainingArguments,
        state: TrainerState,
        logs: dict[str, Any],
        **kwargs: Any,
    ) -> None:
        # Keep a copy in the in-memory history on every rank.
        state.log_history.append(dict(logs))

        # Console/file output below happens on rank-0 only.
        from ...accelerator.interface import DistributedInterface  # lazy import

        if DistributedInterface().get_rank() != 0:
            return

        display_logs = {**logs, "total_steps": state.num_training_steps}

        # Human-readable line for the logger: floats rendered with 4 decimals.
        formatted = []
        for key, value in display_logs.items():
            rendered = f"{key}: {value:.4f}" if isinstance(value, float) else f"{key}: {value}"
            formatted.append(rendered)
        logger.info_rank0(", ".join(formatted))

        # Durable record: one JSON object per line under output_dir.
        os.makedirs(args.output_dir, exist_ok=True)
        log_file = os.path.join(args.output_dir, "trainer_log.jsonl")
        with open(log_file, "a", encoding="utf-8") as f:
            f.write(json.dumps(display_logs, ensure_ascii=False) + "\n")

View File

@@ -0,0 +1,147 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from ...config import TrainingArguments
@dataclass
class TrainerState:
    """Mutable record of training progress handed to every callback hook.

    The trainer updates this object as the run advances; callbacks should
    treat it as a read-only snapshot.
    """

    epoch: int = 0  # current epoch, 0-indexed
    global_step: int = 0  # optimizer steps completed so far
    num_training_steps: int = 0  # total optimizer steps planned for the run
    loss: float = 0.0  # scalar loss value of the most recent step
    grad_norm: float = 0.0  # gradient-norm value of the most recent step
    learning_rate: float = 0.0  # learning rate currently seen by the optimizer
    log_history: list[dict[str, Any]] = field(default_factory=list)  # per-step log dicts emitted by LoggingCallback
class TrainerCallback:
    """Base class defining the trainer's observer hooks.

    Every hook is a no-op; subclasses override only the events they care
    about. Each hook receives:

    - ``args`` the :class:`~llamafactory.v1.config.TrainingArguments`.
    - ``state`` a :class:`TrainerState` snapshot (read-only).
    - ``**kwargs`` extra keyword arguments (model, optimizer, …).

    Callbacks are *observers*: they should NOT mutate training flow.

    Hooks fire in this order::

        on_train_begin
        for each epoch:
            on_epoch_begin
            for each step:
                on_step_begin
                (forward / backward / optimizer.step)
                on_step_end
                [on_log]  ← only when this step is a logging step
            on_epoch_end
        on_train_end
    """

    def on_train_begin(self, args: TrainingArguments, state: TrainerState, **kwargs: Any) -> None:
        """Fired once, before any training step runs."""

    def on_train_end(self, args: TrainingArguments, state: TrainerState, **kwargs: Any) -> None:
        """Fired once, after the final training step."""

    def on_epoch_begin(self, args: TrainingArguments, state: TrainerState, **kwargs: Any) -> None:
        """Fired at the start of every epoch."""

    def on_epoch_end(self, args: TrainingArguments, state: TrainerState, **kwargs: Any) -> None:
        """Fired at the end of every epoch."""

    def on_step_begin(self, args: TrainingArguments, state: TrainerState, **kwargs: Any) -> None:
        """Fired before the forward/backward pass of each optimizer step."""

    def on_step_end(self, args: TrainingArguments, state: TrainerState, **kwargs: Any) -> None:
        """Fired after each optimizer step completes."""

    def on_log(self, args: TrainingArguments, state: TrainerState, logs: dict[str, Any], **kwargs: Any) -> None:
        """Fired whenever the trainer emits a log entry."""

    def on_save(self, args: TrainingArguments, state: TrainerState, **kwargs: Any) -> None:
        """Fired after a model checkpoint has been written to disk."""
class CallbackHandler:
    """Holds :class:`TrainerCallback` instances and dispatches hook events.

    Usage::

        handler = CallbackHandler([LoggingCallback(), MyWandbCallback()], trainer=trainer)
        handler.on_train_begin(args, state)
    """

    def __init__(self, callbacks: list[TrainerCallback] | None = None, trainer: Any = None) -> None:
        # Copy the incoming list so later add_callback calls don't mutate the caller's list.
        self.callbacks: list[TrainerCallback] = list(callbacks or [])
        self.trainer = trainer

    def add_callback(self, callback: TrainerCallback) -> None:
        """Register one more callback at the end of the dispatch order."""
        self.callbacks.append(callback)

    def _call(self, event: str, args: TrainingArguments, state: TrainerState, **kwargs: Any) -> None:
        # Inject trainer internals as defaults; explicit kwargs always win.
        if self.trainer is not None:
            extras = {
                "model": getattr(self.trainer, "model", None),
                "optimizer": getattr(self.trainer, "optimizer", None),
                "lr_scheduler": getattr(self.trainer, "lr_scheduler", None),
                "train_dataloader": getattr(self.trainer, "train_batch_generator", None),
            }
            for key, value in extras.items():
                kwargs.setdefault(key, value)
        for callback in self.callbacks:
            getattr(callback, event)(args, state, **kwargs)

    def on_train_begin(self, args: TrainingArguments, state: TrainerState) -> None:
        self._call("on_train_begin", args, state)

    def on_train_end(self, args: TrainingArguments, state: TrainerState) -> None:
        self._call("on_train_end", args, state)

    def on_epoch_begin(self, args: TrainingArguments, state: TrainerState) -> None:
        self._call("on_epoch_begin", args, state)

    def on_epoch_end(self, args: TrainingArguments, state: TrainerState) -> None:
        self._call("on_epoch_end", args, state)

    def on_step_begin(self, args: TrainingArguments, state: TrainerState) -> None:
        self._call("on_step_begin", args, state)

    def on_step_end(self, args: TrainingArguments, state: TrainerState) -> None:
        self._call("on_step_end", args, state)

    def on_log(self, args: TrainingArguments, state: TrainerState, logs: dict[str, Any]) -> None:
        self._call("on_log", args, state, logs=logs)

    def on_save(self, args: TrainingArguments, state: TrainerState) -> None:
        self._call("on_save", args, state)