[v1] add init on rank0 for fsdp2 (#10264)

This commit is contained in:
jiaqiw09
2026-03-27 14:54:03 +08:00
committed by GitHub
parent d02fcd3588
commit df2e6edb7e
9 changed files with 84 additions and 12 deletions

View File

@@ -1,5 +1,4 @@
model: Qwen/Qwen3-4B model: Qwen/Qwen3-4B
trust_remote_code: true
model_class: llm model_class: llm
template: qwen3_nothink template: qwen3_nothink

View File

@@ -1,5 +1,4 @@
model: Qwen/Qwen3-0.6B model: Qwen/Qwen3-0.6B
model_class: llm model_class: llm
template: qwen3_nothink template: qwen3_nothink

View File

@@ -1,5 +1,4 @@
model: Qwen/Qwen3-0.6B model: Qwen/Qwen3-0.6B
trust_remote_code: true
model_class: llm model_class: llm
template: qwen3_nothink template: qwen3_nothink

View File

@@ -1,5 +1,4 @@
model: Qwen/Qwen3-4B model: Qwen/Qwen3-4B
trust_remote_code: true
model_class: llm model_class: llm
template: qwen3_nothink template: qwen3_nothink
@@ -28,7 +27,6 @@ train_dataset: data/v1_sft_demo.yaml
### training ### training
output_dir: ./outputs/test_lora output_dir: ./outputs/test_lora
micro_batch_size: 1 micro_batch_size: 1
global_batch_size: 4
cutoff_len: 2048 cutoff_len: 2048
learning_rate: 1.0e-4 learning_rate: 1.0e-4
bf16: true bf16: true

View File

@@ -0,0 +1,40 @@
model: Qwen/Qwen3-4B
model_class: llm
template: qwen3_nothink
# PEFT Configuration
peft_config:
name: lora
r: 16
lora_alpha: 32
lora_dropout: 0.05
target_modules: all
# Kernel Config
kernel_config:
name: auto
include_kernels: auto
# FSDP Config
dist_config:
name: fsdp2
dcp_path: null
init_config:
name: init_on_rank0
### data
train_dataset: data/v1_sft_demo.yaml
### training
output_dir: ./outputs/test_lora
micro_batch_size: 1
cutoff_len: 2048
learning_rate: 1.0e-4
bf16: true
max_steps: 10
### sample
sample_backend: hf
max_new_tokens: 128

View File

@@ -1,5 +1,4 @@
model: Qwen/Qwen3-0.6B model: Qwen/Qwen3-0.6B
trust_remote_code: true
model_class: llm model_class: llm
template: qwen3_nothink template: qwen3_nothink

View File

@@ -140,6 +140,9 @@ class ModelEngine:
**init_kwargs, **init_kwargs,
) )
init_mode = self.args.init_config.name if self.args.init_config is not None else "init_on_default"
model._init_mode = init_mode
if self.args.peft_config is None: if self.args.peft_config is None:
if self.is_train: if self.is_train:
logger.info_rank0("Fine-tuning mode: full tuning") logger.info_rank0("Fine-tuning mode: full tuning")
@@ -147,6 +150,9 @@ class ModelEngine:
else: else:
logger.info_rank0("Inference the original model") logger.info_rank0("Inference the original model")
else: else:
if self.args.peft_config.name == "lora" and init_mode == "init_on_meta":
raise ValueError("Currently lora stage does not support loading model by meta.")
from ..plugins.model_plugins.peft import PeftPlugin from ..plugins.model_plugins.peft import PeftPlugin
model = PeftPlugin(self.args.peft_config.name)(model, self.args.peft_config, self.is_train) model = PeftPlugin(self.args.peft_config.name)(model, self.args.peft_config, self.is_train)

View File

@@ -150,9 +150,6 @@ def load_adapter(model: HFModel, adapter_name_or_path: Union[list[str], str], is
@PeftPlugin("lora").register() @PeftPlugin("lora").register()
def get_lora_model(model: HFModel, config: LoraConfigDict, is_train: bool = False) -> HFModel: def get_lora_model(model: HFModel, config: LoraConfigDict, is_train: bool = False) -> HFModel:
if model.device.type == "meta":
raise ValueError("Currently lora stage does not support loading model by meta.")
adapter_name_or_path = config.get("adapter_name_or_path") adapter_name_or_path = config.get("adapter_name_or_path")
if adapter_name_or_path: if adapter_name_or_path:

View File

@@ -17,6 +17,7 @@ import gc
import os import os
import torch import torch
import torch.distributed as dist
import torch.nn as nn import torch.nn as nn
from peft.tuners.lora import LoraLayer from peft.tuners.lora import LoraLayer
from torch.distributed.checkpoint.state_dict import StateDictOptions, get_model_state_dict, set_model_state_dict from torch.distributed.checkpoint.state_dict import StateDictOptions, get_model_state_dict, set_model_state_dict
@@ -244,23 +245,57 @@ class FSDP2Engine:
logger.info(f"Restored {len(saved_buffers)} non-persistent buffers") logger.info(f"Restored {len(saved_buffers)} non-persistent buffers")
def shard_model(self, model: HFModel) -> HFModel: def shard_model(self, model: HFModel) -> HFModel:
if model.device.type == "meta": init_mode = getattr(model, "_init_mode", "init_on_default")
if init_mode == "init_on_rank0":
if getattr(model.config, "tie_word_embeddings", False):
model.tie_weights()
if self.rank == 0:
logger.info("init_on_rank0 detected: sharding then scattering Rank 0 CPU weights.")
full_sd = {k: v.clone() for k, v in model.state_dict().items()}
else:
full_sd = {}
# Reuse existing helper to save persistent=False buffers (e.g. inv_freq) before shard
saved_buffers = self._save_non_persistent_buffers(model) if self.rank == 0 else {}
model = self.prepare_model(model)
device = get_current_accelerator()
model.to_empty(device=device)
# Scatter params from Rank 0 into all DTensor shards
# Broadcast the full state dict from the global rank-0 process to all ranks in this group.
options = StateDictOptions(full_state_dict=True, cpu_offload=True, broadcast_from_rank0=True)
set_model_state_dict(model, full_sd, options=options)
# Broadcast and restore non-persistent buffers
buffers_to_sync = [saved_buffers]
dist.broadcast_object_list(buffers_to_sync, src=0, group=self.fsdp_mesh.get_group())
self._restore_non_persistent_buffers(model, buffers_to_sync[0])
if self.rank == 0:
logger.info("init_on_rank0 sync complete.")
elif init_mode == "init_on_meta":
non_persistent_buffers = self._save_non_persistent_buffers(model) non_persistent_buffers = self._save_non_persistent_buffers(model)
if getattr(model.config, "tie_word_embeddings", None): if getattr(model.config, "tie_word_embeddings", False):
model.tie_weights() model.tie_weights()
model = self.prepare_model(model) model = self.prepare_model(model)
model = self.materialize_and_load(model, hf_model_path=model.config.name_or_path, dcp_path=self.dcp_path) model = self.materialize_and_load(model, hf_model_path=model.config.name_or_path, dcp_path=self.dcp_path)
# fix tied broken for no-fsdp-wrap case # fix tied broken for no-fsdp-wrap case
if getattr(model.config, "tie_word_embeddings", None): if getattr(model.config, "tie_word_embeddings", False):
model.tie_weights() model.tie_weights()
self._restore_non_persistent_buffers(model, non_persistent_buffers) self._restore_non_persistent_buffers(model, non_persistent_buffers)
else: else:
model = self.prepare_model(model) model = self.prepare_model(model)
return model return model
def _load_from_dcp(self, model: HFModel, dcp_path: str): def _load_from_dcp(self, model: HFModel, dcp_path: str):