[fix] fit neat_packing & mrope model packing (#10283)

Co-authored-by: Yaowei Zheng <hiyouga@buaa.edu.cn>
This commit is contained in:
Kingsley
2026-03-20 16:50:11 +08:00
committed by GitHub
parent d91d8af89e
commit 833f6027b1
15 changed files with 520 additions and 93 deletions

View File

@@ -13,7 +13,7 @@
# limitations under the License.
from collections import defaultdict
from dataclasses import dataclass
from dataclasses import asdict, dataclass
from typing import TYPE_CHECKING, Any, Optional
from ...extras import logging
@@ -27,6 +27,23 @@ if TYPE_CHECKING:
logger = logging.get_logger(__name__)
# Sentinel sub-sequence index. The packing code in this file stores [MAX_SU_SEQ_IDX] in place of
# an empty image/video/audio sub-sequence id list (see its "avoid dataset concat error" note),
# so this value marks "no multimodal item of this type" rather than a real sub-sequence.
MAX_SU_SEQ_IDX = 2**32  # maximum sub-sequence index
@dataclass
class PackingParams:
    r"""Metadata for a packed sequence: sub-sequence boundaries and multimodal data indices.

    - sequence_boundaries: cumulative token positions, e.g. [0, 100, 250, 512] means 3 sub-seqs
      with token ranges [0,100), [100,250), [250,512). Length = num_sub_seqs + 1.
      NOTE(review): the packing code in this file also appends one more boundary that covers the
      padding span, so a padded pack appears to carry one extra trailing range. Confirm that
      consumers treat that last range as padding rather than as a real sub-sequence.
    - image_subseq_ids / video_subseq_ids / audio_subseq_ids: for each mm item, the 0-based
      sub-sequence index it belongs to. Length = total number of that mm type in the packed sample.
      When a pack holds no item of a given type, the packing code stores [MAX_SU_SEQ_IDX] instead
      of an empty list (placeholder to avoid a dataset concat error); consumers presumably need
      to skip that sentinel value.
    - right_padding_length: the pad length the packing code applied to this pack. Presumably the
      number of padding tokens appended on the right, as the name suggests (TODO confirm).
    """

    sequence_boundaries: list[int]
    image_subseq_ids: list[int]
    video_subseq_ids: list[int]
    audio_subseq_ids: list[int]
    right_padding_length: int
@dataclass
class SupervisedDatasetProcessor(DatasetProcessor):
@@ -162,10 +179,17 @@ class PackedSupervisedDatasetProcessor(SupervisedDatasetProcessor):
valid_num += 1
model_inputs = defaultdict(list)
requires_packing_params = self.data_args.neat_packing
knapsacks = greedy_knapsack(lengths, self.data_args.cutoff_len)
for knapsack in knapsacks:
packed_input_ids, packed_attention_masks, packed_position_ids, packed_labels = [], [], [], []
packed_images, packed_videos, packed_audios = [], [], []
if requires_packing_params:
sequence_boundaries = [0]
image_subseq_ids: list[int] = []
video_subseq_ids: list[int] = []
audio_subseq_ids: list[int] = []
for i, length in enumerate(knapsack):
index = length2indexes[length].pop()
packed_input_ids += batch_input_ids[index]
@@ -174,6 +198,15 @@ class PackedSupervisedDatasetProcessor(SupervisedDatasetProcessor):
packed_images += batch_images[index]
packed_videos += batch_videos[index]
packed_audios += batch_audios[index]
if requires_packing_params:
n_img = len(batch_images[index])
n_vid = len(batch_videos[index])
n_aud = len(batch_audios[index])
sequence_boundaries.append(sequence_boundaries[-1] + len(batch_input_ids[index]))
image_subseq_ids.extend([i] * n_img)
video_subseq_ids.extend([i] * n_vid)
audio_subseq_ids.extend([i] * n_aud)
if self.data_args.neat_packing:
packed_attention_masks += [i + 1] * len(batch_input_ids[index]) # start from 1
else:
@@ -189,10 +222,23 @@ class PackedSupervisedDatasetProcessor(SupervisedDatasetProcessor):
else:
packed_attention_masks += [1] * pad_length # more efficient flash_attn
if requires_packing_params:
sequence_boundaries.append(sequence_boundaries[-1] + pad_length)
if len(packed_input_ids) != self.data_args.cutoff_len + 1:
raise ValueError("The length of packed example should be identical to the cutoff length.")
model_inputs["input_ids"].append(packed_input_ids)
if requires_packing_params:
packing_params = PackingParams(
sequence_boundaries=sequence_boundaries,
image_subseq_ids=image_subseq_ids or [MAX_SU_SEQ_IDX], # avoid dataset concat error
video_subseq_ids=video_subseq_ids or [MAX_SU_SEQ_IDX],
audio_subseq_ids=audio_subseq_ids or [MAX_SU_SEQ_IDX],
right_padding_length=pad_length,
)
model_inputs["packing_params"].append(asdict(packing_params))
model_inputs["attention_mask"].append(packed_attention_masks)
model_inputs["position_ids"].append(packed_position_ids)
model_inputs["labels"].append(packed_labels)