mirror of
https://github.com/hiyouga/LlamaFactory.git
synced 2026-03-22 09:53:24 +08:00
refactor mm training
Former-commit-id: 179c0558699e287cbf38a2d73bff47e86d589c5a
This commit is contained in:
@@ -17,17 +17,10 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
|
||||
|
||||
from ...extras.constants import IGNORE_INDEX
|
||||
from ...extras.logging import get_logger
|
||||
from .processor_utils import (
|
||||
get_paligemma_token_type_ids,
|
||||
get_pixel_values,
|
||||
get_qwen2vl_image_inputs,
|
||||
greedy_knapsack,
|
||||
infer_seqlen,
|
||||
)
|
||||
from .processor_utils import greedy_knapsack, infer_seqlen
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from PIL.Image import Image as ImageObject
|
||||
from transformers import PreTrainedTokenizer, ProcessorMixin
|
||||
|
||||
from ...hparams import DataArguments
|
||||
@@ -43,41 +36,15 @@ def _encode_supervised_example(
|
||||
system: Optional[str],
|
||||
tools: Optional[str],
|
||||
template: "Template",
|
||||
images: Sequence["ImageObject"],
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
processor: Optional["ProcessorMixin"],
|
||||
cutoff_len: int,
|
||||
train_on_prompt: bool,
|
||||
mask_history: bool,
|
||||
) -> Tuple[List[int], List[int]]:
|
||||
if processor is not None and "image_grid_thw" in processor.model_input_names: # qwen2_vl models
|
||||
image_processor = getattr(processor, "image_processor")
|
||||
merge_length = image_processor.merge_size**2
|
||||
if len(images) > 0:
|
||||
image_grid_thw = get_qwen2vl_image_inputs(images, processor)["image_grid_thw"]
|
||||
index = 0
|
||||
for message in prompt:
|
||||
content = message["content"]
|
||||
while "<|image_pad|>" in content:
|
||||
content = content.replace(
|
||||
"<|image_pad|>",
|
||||
template.vision_start_token
|
||||
+ "<|placeholder|>" * (image_grid_thw[index].prod() // merge_length)
|
||||
+ template.vision_end_token,
|
||||
1,
|
||||
)
|
||||
index += 1
|
||||
message["content"] = content.replace("<|placeholder|>", "<|image_pad|>")
|
||||
elif processor is not None and not hasattr(processor, "image_seq_length"): # llava-like models
|
||||
prompt[0]["content"] = template.image_token + prompt[0]["content"]
|
||||
|
||||
messages = prompt + response
|
||||
input_ids, labels = [], []
|
||||
|
||||
if processor is not None and hasattr(processor, "image_seq_length"): # paligemma models
|
||||
image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
|
||||
input_ids += [image_token_id] * getattr(processor, "image_seq_length")
|
||||
labels += [IGNORE_INDEX] * getattr(processor, "image_seq_length")
|
||||
input_ids, labels = template.mm_plugin.process_token_ids(input_ids, labels, tokenizer, processor)
|
||||
|
||||
encoded_pairs = template.encode_multiturn(tokenizer, messages, system, tools)
|
||||
total_length = 1 if template.efficient_eos else 0
|
||||
@@ -125,28 +92,21 @@ def preprocess_supervised_dataset(
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
processor: Optional["ProcessorMixin"],
|
||||
data_args: "DataArguments",
|
||||
) -> Dict[str, List[List[int]]]:
|
||||
) -> Dict[str, List[Any]]:
|
||||
# build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
|
||||
# for multiturn examples, we only mask the prompt part in each prompt-response pair.
|
||||
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
|
||||
if processor is not None:
|
||||
model_inputs["pixel_values"] = []
|
||||
if hasattr(processor, "image_seq_length"): # paligemma models
|
||||
model_inputs["token_type_ids"] = []
|
||||
if "image_grid_thw" in processor.model_input_names: # qwen2_vl models
|
||||
model_inputs["image_grid_thw"] = []
|
||||
|
||||
model_inputs = defaultdict(list)
|
||||
for i in range(len(examples["prompt"])):
|
||||
if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) != 1:
|
||||
logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
|
||||
continue
|
||||
|
||||
prompt = template.mm_plugin.process_messages(examples["prompt"][i], examples["images"][i], processor)
|
||||
input_ids, labels = _encode_supervised_example(
|
||||
prompt=examples["prompt"][i],
|
||||
prompt=prompt,
|
||||
response=examples["response"][i],
|
||||
system=examples["system"][i],
|
||||
tools=examples["tools"][i],
|
||||
images=examples["images"][i],
|
||||
template=template,
|
||||
tokenizer=tokenizer,
|
||||
processor=processor,
|
||||
@@ -157,15 +117,12 @@ def preprocess_supervised_dataset(
|
||||
model_inputs["input_ids"].append(input_ids)
|
||||
model_inputs["attention_mask"].append([1] * len(input_ids))
|
||||
model_inputs["labels"].append(labels)
|
||||
if processor is not None:
|
||||
if "image_grid_thw" in processor.model_input_names: # qwen2_vl models
|
||||
image_inputs = get_qwen2vl_image_inputs(examples["images"][i], processor)
|
||||
model_inputs["pixel_values"].append(image_inputs["pixel_values"])
|
||||
model_inputs["image_grid_thw"].append(image_inputs["image_grid_thw"])
|
||||
else:
|
||||
model_inputs["pixel_values"].append(get_pixel_values(examples["images"][i], processor))
|
||||
if hasattr(processor, "image_seq_length"): # paligemma models
|
||||
model_inputs["token_type_ids"].append(get_paligemma_token_type_ids(len(input_ids), processor))
|
||||
template.mm_plugin.process_model_inputs(
|
||||
model_inputs=model_inputs,
|
||||
images=examples["images"][i],
|
||||
feature_seqlens={"token_type_ids": len(input_ids)},
|
||||
processor=processor,
|
||||
)
|
||||
|
||||
return model_inputs
|
||||
|
||||
@@ -175,7 +132,7 @@ def preprocess_packed_supervised_dataset(
|
||||
template: "Template",
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
data_args: "DataArguments",
|
||||
) -> Dict[str, List[List[int]]]:
|
||||
) -> Dict[str, List[Any]]:
|
||||
# build inputs with format `<bos> X1 Y1 <eos> <bos> X2 Y2 <eos>`
|
||||
# and labels with format `<ignore> ... <ignore> Y1 <eos> <ignore> ... <ignore> Y2 <eos>`
|
||||
valid_num = 0
|
||||
@@ -209,7 +166,7 @@ def preprocess_packed_supervised_dataset(
|
||||
batch_labels.append(labels)
|
||||
valid_num += 1
|
||||
|
||||
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
|
||||
model_inputs = defaultdict(list)
|
||||
knapsacks = greedy_knapsack(lengths, data_args.cutoff_len - 1) # reserved for the padding token
|
||||
for knapsack in knapsacks:
|
||||
packed_input_ids, packed_attention_masks, packed_labels = [], [], []
|
||||
|
||||
Reference in New Issue
Block a user