mirror of
https://github.com/hiyouga/LlamaFactory.git
synced 2026-03-22 18:03:23 +08:00
[fix] fit neat_packing & mrope model packing (#10283)
Co-authored-by: Yaowei Zheng <hiyouga@buaa.edu.cn>
This commit is contained in:
@@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import asdict, dataclass
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
from ...extras import logging
|
||||
@@ -27,6 +27,23 @@ if TYPE_CHECKING:
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
# Sentinel sub-sequence index. The packing code in this file stores `[MAX_SU_SEQ_IDX]`
# in place of an empty image/video/audio sub-sequence id list (noted there as avoiding a
# dataset concat error), so this value should never be read as a real sub-sequence index.
MAX_SU_SEQ_IDX = 2**32 # maximum sub-sequence index; doubles as the "no item" placeholder
|
||||
|
||||
@dataclass
class PackingParams:
    r"""Metadata for a packed sequence: sub-sequence boundaries and multimodal data indices.

    - sequence_boundaries: cumulative token positions, e.g. [0, 100, 250, 512] means 3 ranges
        with token spans [0,100), [100,250), [250,512). Length = number of ranges + 1.
        The packer also appends one boundary for the right-padding span after the real
        sub-sequences, so the final range may be padding rather than a sample
        (compare with ``right_padding_length``).
    - image_subseq_ids / video_subseq_ids / audio_subseq_ids: for each mm item, the 0-based
        sub-sequence index it belongs to. Length = total number of that mm type in the packed
        sample, except that the packer stores the single placeholder ``[MAX_SU_SEQ_IDX]``
        instead of an empty list when the sample has no item of that type. Consumers must
        treat ``MAX_SU_SEQ_IDX`` as "no item", not as a real index.
    - right_padding_length: the pad length the packer recorded for this packed sample
        (presumably the number of padding tokens appended on the right; 0 when none).
    """

    # Cumulative token offsets delimiting each range of the packed sample; starts at 0.
    sequence_boundaries: list[int]
    # Per-image owning sub-sequence index, or [MAX_SU_SEQ_IDX] when there are no images.
    image_subseq_ids: list[int]
    # Per-video owning sub-sequence index, or [MAX_SU_SEQ_IDX] when there are no videos.
    video_subseq_ids: list[int]
    # Per-audio owning sub-sequence index, or [MAX_SU_SEQ_IDX] when there are no audios.
    audio_subseq_ids: list[int]
    # Pad length recorded by the packer for this sample.
    right_padding_length: int
|
||||
|
||||
@dataclass
|
||||
class SupervisedDatasetProcessor(DatasetProcessor):
|
||||
@@ -162,10 +179,17 @@ class PackedSupervisedDatasetProcessor(SupervisedDatasetProcessor):
|
||||
valid_num += 1
|
||||
|
||||
model_inputs = defaultdict(list)
|
||||
requires_packing_params = self.data_args.neat_packing
|
||||
knapsacks = greedy_knapsack(lengths, self.data_args.cutoff_len)
|
||||
for knapsack in knapsacks:
|
||||
packed_input_ids, packed_attention_masks, packed_position_ids, packed_labels = [], [], [], []
|
||||
packed_images, packed_videos, packed_audios = [], [], []
|
||||
if requires_packing_params:
|
||||
sequence_boundaries = [0]
|
||||
image_subseq_ids: list[int] = []
|
||||
video_subseq_ids: list[int] = []
|
||||
audio_subseq_ids: list[int] = []
|
||||
|
||||
for i, length in enumerate(knapsack):
|
||||
index = length2indexes[length].pop()
|
||||
packed_input_ids += batch_input_ids[index]
|
||||
@@ -174,6 +198,15 @@ class PackedSupervisedDatasetProcessor(SupervisedDatasetProcessor):
|
||||
packed_images += batch_images[index]
|
||||
packed_videos += batch_videos[index]
|
||||
packed_audios += batch_audios[index]
|
||||
if requires_packing_params:
|
||||
n_img = len(batch_images[index])
|
||||
n_vid = len(batch_videos[index])
|
||||
n_aud = len(batch_audios[index])
|
||||
sequence_boundaries.append(sequence_boundaries[-1] + len(batch_input_ids[index]))
|
||||
image_subseq_ids.extend([i] * n_img)
|
||||
video_subseq_ids.extend([i] * n_vid)
|
||||
audio_subseq_ids.extend([i] * n_aud)
|
||||
|
||||
if self.data_args.neat_packing:
|
||||
packed_attention_masks += [i + 1] * len(batch_input_ids[index]) # start from 1
|
||||
else:
|
||||
@@ -189,10 +222,23 @@ class PackedSupervisedDatasetProcessor(SupervisedDatasetProcessor):
|
||||
else:
|
||||
packed_attention_masks += [1] * pad_length # more efficient flash_attn
|
||||
|
||||
if requires_packing_params:
|
||||
sequence_boundaries.append(sequence_boundaries[-1] + pad_length)
|
||||
|
||||
if len(packed_input_ids) != self.data_args.cutoff_len + 1:
|
||||
raise ValueError("The length of packed example should be identical to the cutoff length.")
|
||||
|
||||
model_inputs["input_ids"].append(packed_input_ids)
|
||||
if requires_packing_params:
|
||||
packing_params = PackingParams(
|
||||
sequence_boundaries=sequence_boundaries,
|
||||
image_subseq_ids=image_subseq_ids or [MAX_SU_SEQ_IDX], # avoid dataset concat error
|
||||
video_subseq_ids=video_subseq_ids or [MAX_SU_SEQ_IDX],
|
||||
audio_subseq_ids=audio_subseq_ids or [MAX_SU_SEQ_IDX],
|
||||
right_padding_length=pad_length,
|
||||
)
|
||||
model_inputs["packing_params"].append(asdict(packing_params))
|
||||
|
||||
model_inputs["attention_mask"].append(packed_attention_masks)
|
||||
model_inputs["position_ids"].append(packed_position_ids)
|
||||
model_inputs["labels"].append(packed_labels)
|
||||
|
||||
Reference in New Issue
Block a user