diff --git a/src/llamafactory/v1/config/training_args.py b/src/llamafactory/v1/config/training_args.py index 5d13ef2fb..750899ccd 100644 --- a/src/llamafactory/v1/config/training_args.py +++ b/src/llamafactory/v1/config/training_args.py @@ -85,6 +85,10 @@ class TrainingArguments: default=42, metadata={"help": "Random seed that will be set at the beginning of training."}, ) + logging_steps: int = field( + default=1, + metadata={"help": "Log metrics every N optimizer steps."}, + ) def __post_init__(self) -> None: self.dist_config = get_plugin_config(self.dist_config) diff --git a/src/llamafactory/v1/core/base_trainer.py b/src/llamafactory/v1/core/base_trainer.py index 3e277e435..cee4abb81 100644 --- a/src/llamafactory/v1/core/base_trainer.py +++ b/src/llamafactory/v1/core/base_trainer.py @@ -36,6 +36,12 @@ from ..accelerator.helper import ReduceOp from ..accelerator.interface import Dim, DistributedInterface from ..config import TrainingArguments from ..utils import logging +from ..utils.callbacks import ( + CallbackHandler, + LoggingCallback, + TrainerCallback, + TrainerState, +) from ..utils.helper import compute_valid_tokens from ..utils.types import BatchInput, HFModel, ModelOutput, Tensor, TorchDataset from .utils.batching import BatchGenerator @@ -52,6 +58,7 @@ class BaseTrainer: model: HFModel, renderer: Renderer, train_dataset: TorchDataset, + callbacks: list[TrainerCallback] | None = None, ) -> None: self.args = args self.model = model @@ -99,6 +106,14 @@ class BaseTrainer: self._init_optimizer() self._init_lr_scheduler() + # Callbacks + self.callback_handler = CallbackHandler([LoggingCallback()], trainer=self) + for cb in callbacks or []: + self.callback_handler.add_callback(cb) + + # Callbacks: TrainerState tracks progress across the full run. 
+ self.state = TrainerState(num_training_steps=self.num_training_steps) + def _create_batch_generator(self) -> None: self.train_batch_generator = BatchGenerator( dataset=self.train_dataset, @@ -174,10 +189,18 @@ class BaseTrainer: def fit(self) -> None: """Train the model.""" self.model.train() + self.callback_handler.on_train_begin(self.args, self.state) for epoch in range(self.args.num_train_epochs): + self.state.epoch = epoch self.train_batch_generator.set_epoch(epoch) + self.callback_handler.on_epoch_begin(self.args, self.state) + for micro_batches in self.train_batch_generator: self.global_step += 1 + + self.state.global_step = self.global_step + self.callback_handler.on_step_begin(self.args, self.state) + step_loss = 0 step_valid_tokens = compute_valid_tokens(micro_batches) step_valid_tokens = DistributedInterface().all_reduce(step_valid_tokens, op=ReduceOp.SUM) @@ -213,14 +236,41 @@ class BaseTrainer: step_loss, grad_norm = DistributedInterface().all_reduce([step_loss, grad_norm]) DistributedInterface().sync() - if DistributedInterface().get_rank() == 0: - print(f"Epoch {epoch}, Step {self.global_step}, Loss: {step_loss:.4f}, Grad Norm: {grad_norm:.4f}") + + # Update state with step metrics + current_lr = ( + self.lr_scheduler.get_last_lr()[0] + if hasattr(self.lr_scheduler, "get_last_lr") + else self.args.learning_rate + ) + self.state.loss = step_loss + self.state.grad_norm = grad_norm + self.state.learning_rate = current_lr + + self.callback_handler.on_step_end(self.args, self.state) + + # Logging: trainer decides when to log + if self.global_step % self.args.logging_steps == 0: + logs = { + "epoch": epoch, + "step": self.global_step, + "loss": step_loss, + "grad_norm": grad_norm, + "learning_rate": current_lr, + } + self.callback_handler.on_log(self.args, self.state, logs) # Check if max_steps is reached if self.global_step >= self.num_training_steps: logger.info_rank0(f"Reached max_steps ({self.num_training_steps}), stopping training.") + 
self.callback_handler.on_epoch_end(self.args, self.state) + self.callback_handler.on_train_end(self.args, self.state) return + self.callback_handler.on_epoch_end(self.args, self.state) + + self.callback_handler.on_train_end(self.args, self.state) + def save_model(self) -> None: """Save the model.""" if self.args.dist_config is not None and self.args.dist_config.name in ("deepspeed", "fsdp2"): @@ -234,3 +284,5 @@ class BaseTrainer: model_to_save.save_pretrained(self.args.output_dir, max_shard_size="4GB") self.renderer.processor.save_pretrained(self.args.output_dir, max_shard_size="4GB") logger.info_rank0(f"Model saved to {self.args.output_dir}") + + self.callback_handler.on_save(self.args, self.state) diff --git a/src/llamafactory/v1/utils/callbacks/__init__.py b/src/llamafactory/v1/utils/callbacks/__init__.py new file mode 100644 index 000000000..0da31ed3f --- /dev/null +++ b/src/llamafactory/v1/utils/callbacks/__init__.py @@ -0,0 +1,24 @@ +# Copyright 2025 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+from .logging_callback import LoggingCallback
+from .trainer_callback import CallbackHandler, TrainerCallback, TrainerState
+
+
+__all__ = [
+    "CallbackHandler",
+    "LoggingCallback",
+    "TrainerCallback",
+    "TrainerState",
+]
diff --git a/src/llamafactory/v1/utils/callbacks/logging_callback.py b/src/llamafactory/v1/utils/callbacks/logging_callback.py
new file mode 100644
index 000000000..d6bdba604
--- /dev/null
+++ b/src/llamafactory/v1/utils/callbacks/logging_callback.py
@@ -0,0 +1,64 @@
+# Copyright 2025 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import json
+import os
+from typing import TYPE_CHECKING, Any
+
+from .. import logging
+from .trainer_callback import TrainerCallback, TrainerState
+
+
+if TYPE_CHECKING:
+    from ...config import TrainingArguments
+
+
+logger = logging.get_logger(__name__)
+
+
+class LoggingCallback(TrainerCallback):
+    """Logs training metrics to stdout on rank-0 and appends to ``state.log_history``.
+
+    On each logging step the entry is also persisted as a JSON line in
+    ``{args.output_dir}/trainer_log.jsonl`` so that training history survives crashes.
+ """ + + def on_log( + self, + args: TrainingArguments, + state: TrainerState, + logs: dict[str, Any], + **kwargs: Any, + ) -> None: + # Persist in history regardless of rank + state.log_history.append(dict(logs)) + + # Everything below is rank-0 only + from ...accelerator.interface import DistributedInterface # lazy import + + if DistributedInterface().get_rank() != 0: + return + + # Human-readable output to stdout + display_logs = {**logs, "total_steps": state.num_training_steps} + parts = ", ".join(f"{k}: {v:.4f}" if isinstance(v, float) else f"{k}: {v}" for k, v in display_logs.items()) + logger.info_rank0(parts) + + # Append to JSONL log file in output_dir + os.makedirs(args.output_dir, exist_ok=True) + log_file = os.path.join(args.output_dir, "trainer_log.jsonl") + with open(log_file, "a", encoding="utf-8") as f: + f.write(json.dumps(display_logs, ensure_ascii=False) + "\n") diff --git a/src/llamafactory/v1/utils/callbacks/trainer_callback.py b/src/llamafactory/v1/utils/callbacks/trainer_callback.py new file mode 100644 index 000000000..400514d29 --- /dev/null +++ b/src/llamafactory/v1/utils/callbacks/trainer_callback.py @@ -0,0 +1,147 @@ +# Copyright 2025 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+from __future__ import annotations
+
+from dataclasses import dataclass, field
+from typing import TYPE_CHECKING, Any
+
+
+if TYPE_CHECKING:
+    from ...config import TrainingArguments
+
+
+@dataclass
+class TrainerState:
+    """Training-progress state, updated in place by the trainer; callbacks must treat it as read-only.
+
+    Attributes:
+        epoch: Current epoch (0-indexed).
+        global_step: Number of optimizer steps completed so far.
+        num_training_steps: Total number of optimizer steps planned.
+        loss: Scalar loss value of the most recent step.
+        grad_norm: Gradient-norm value of the most recent step.
+        learning_rate: Current learning rate seen by the optimizer.
+        log_history: List of per-step log dicts emitted by ``LoggingCallback``.
+    """
+
+    epoch: int = 0
+    global_step: int = 0
+    num_training_steps: int = 0
+    loss: float = 0.0
+    grad_norm: float = 0.0
+    learning_rate: float = 0.0
+    log_history: list[dict[str, Any]] = field(default_factory=list)
+
+
+class TrainerCallback:
+    """Abstract base class for training callbacks.
+
+    Subclass and override whichever hooks you need. All hooks receive:
+
+    - ``args`` – the :class:`~llamafactory.v1.config.TrainingArguments`.
+    - ``state`` – a :class:`TrainerState` snapshot (read-only).
+    - ``**kwargs`` – extra keyword arguments (model, optimizer, …).
+
+    Callbacks are *observers*: they should NOT mutate training flow.
+ + Hook call order:: + + on_train_begin + for each epoch: + on_epoch_begin + for each step: + on_step_begin + (forward / backward / optimizer.step) + on_step_end + [on_log] ← if this step is a logging step + on_epoch_end + on_train_end + """ + + def on_train_begin(self, args: TrainingArguments, state: TrainerState, **kwargs: Any) -> None: + """Called once before the first training step.""" + + def on_train_end(self, args: TrainingArguments, state: TrainerState, **kwargs: Any) -> None: + """Called once after the last training step.""" + + def on_epoch_begin(self, args: TrainingArguments, state: TrainerState, **kwargs: Any) -> None: + """Called at the beginning of each epoch.""" + + def on_epoch_end(self, args: TrainingArguments, state: TrainerState, **kwargs: Any) -> None: + """Called at the end of each epoch.""" + + def on_step_begin(self, args: TrainingArguments, state: TrainerState, **kwargs: Any) -> None: + """Called before the forward/backward pass of each optimizer step.""" + + def on_step_end(self, args: TrainingArguments, state: TrainerState, **kwargs: Any) -> None: + """Called after the optimizer step.""" + + def on_log(self, args: TrainingArguments, state: TrainerState, logs: dict[str, Any], **kwargs: Any) -> None: + """Called when the trainer emits a log entry.""" + + def on_save(self, args: TrainingArguments, state: TrainerState, **kwargs: Any) -> None: + """Called after the model checkpoint has been written to disk.""" + + +class CallbackHandler: + """Owns a list of :class:`TrainerCallback` instances and fans out hook calls. 
+ + Usage:: + + handler = CallbackHandler([LoggingCallback(), MyWandbCallback()], trainer=trainer) + handler.on_train_begin(args, state) + """ + + def __init__(self, callbacks: list[TrainerCallback] | None = None, trainer: Any = None) -> None: + self.callbacks: list[TrainerCallback] = list(callbacks or []) + self.trainer = trainer + + def add_callback(self, callback: TrainerCallback) -> None: + """Append a callback to the handler.""" + self.callbacks.append(callback) + + def _call(self, event: str, args: TrainingArguments, state: TrainerState, **kwargs: Any) -> None: + if self.trainer is not None: + kwargs.setdefault("model", getattr(self.trainer, "model", None)) + kwargs.setdefault("optimizer", getattr(self.trainer, "optimizer", None)) + kwargs.setdefault("lr_scheduler", getattr(self.trainer, "lr_scheduler", None)) + kwargs.setdefault("train_dataloader", getattr(self.trainer, "train_batch_generator", None)) + + for cb in self.callbacks: + getattr(cb, event)(args, state, **kwargs) + + def on_train_begin(self, args: TrainingArguments, state: TrainerState) -> None: + self._call("on_train_begin", args, state) + + def on_train_end(self, args: TrainingArguments, state: TrainerState) -> None: + self._call("on_train_end", args, state) + + def on_epoch_begin(self, args: TrainingArguments, state: TrainerState) -> None: + self._call("on_epoch_begin", args, state) + + def on_epoch_end(self, args: TrainingArguments, state: TrainerState) -> None: + self._call("on_epoch_end", args, state) + + def on_step_begin(self, args: TrainingArguments, state: TrainerState) -> None: + self._call("on_step_begin", args, state) + + def on_step_end(self, args: TrainingArguments, state: TrainerState) -> None: + self._call("on_step_end", args, state) + + def on_log(self, args: TrainingArguments, state: TrainerState, logs: dict[str, Any]) -> None: + self._call("on_log", args, state, logs=logs) + + def on_save(self, args: TrainingArguments, state: TrainerState) -> None: + self._call("on_save", args, 
state)