[model] support audio (#6701)

* support qwen2_audio

* improve code

* lint

* fix

* fix

* fix

---------

Co-authored-by: hiyouga <hiyouga@buaa.edu.cn>
Former-commit-id: 5eacb5629e4d7733cd992a63747a1335f2c6a929
This commit is contained in:
Zhangchi Feng
2025-02-05 04:59:09 +08:00
committed by GitHub
parent 9feb78e7b4
commit 8f401e37f8
35 changed files with 675 additions and 213 deletions

View File

@@ -25,57 +25,33 @@ if TYPE_CHECKING:
from transformers import Seq2SeqTrainingArguments
from ..hparams import DataArguments
from .mm_plugin import ImageInput, VideoInput
from .parser import DatasetAttr
logger = logging.get_logger(__name__)
def _convert_images(
images: Union["ImageInput", Sequence["ImageInput"]],
def _regularize_medias(
inputs: Union[Any, Sequence[Any]],
dataset_attr: "DatasetAttr",
data_args: "DataArguments",
) -> Optional[List["ImageInput"]]:
) -> Optional[List[Any]]:
r"""
Optionally concatenates image path to dataset dir when loading from local disk.
Optionally concatenates media path to media dir when loading from local disk.
"""
if not isinstance(images, list):
images = [images]
elif len(images) == 0:
if not isinstance(inputs, list):
inputs = [inputs]
elif len(inputs) == 0:
return None
else:
images = images[:]
inputs = inputs[:]
if dataset_attr.load_from in ["script", "file"]:
for i in range(len(images)):
if isinstance(images[i], str) and os.path.isfile(os.path.join(data_args.image_dir, images[i])):
images[i] = os.path.join(data_args.image_dir, images[i])
for i in range(len(inputs)):
if isinstance(inputs[i], str) and os.path.isfile(os.path.join(data_args.media_dir, inputs[i])):
inputs[i] = os.path.join(data_args.media_dir, inputs[i])
return images
def _convert_videos(
videos: Union["VideoInput", Sequence["VideoInput"]],
dataset_attr: "DatasetAttr",
data_args: "DataArguments",
) -> Optional[List["VideoInput"]]:
r"""
Optionally concatenates video path to dataset dir when loading from local disk.
"""
if not isinstance(videos, list):
videos = [videos]
elif len(videos) == 0:
return None
else:
videos = videos[:]
if dataset_attr.load_from in ["script", "file"]:
for i in range(len(videos)):
if isinstance(videos[i], str) and os.path.isfile(os.path.join(data_args.image_dir, videos[i])):
videos[i] = os.path.join(data_args.image_dir, videos[i])
return videos
return inputs
def convert_alpaca(
@@ -121,15 +97,15 @@ def convert_alpaca(
else: # unsupervised
response = []
convert_images = partial(_convert_images, dataset_attr=dataset_attr, data_args=data_args)
convert_videos = partial(_convert_videos, dataset_attr=dataset_attr, data_args=data_args)
regularize_medias = partial(_regularize_medias, dataset_attr=dataset_attr, data_args=data_args)
output = {
"_prompt": prompt,
"_response": response,
"_system": example[dataset_attr.system] if dataset_attr.system else "",
"_tools": example[dataset_attr.tools] if dataset_attr.tools else "",
"_images": convert_images(example[dataset_attr.images]) if dataset_attr.images else None,
"_videos": convert_videos(example[dataset_attr.videos]) if dataset_attr.videos else None,
"_images": regularize_medias(example[dataset_attr.images]) if dataset_attr.images else None,
"_videos": regularize_medias(example[dataset_attr.videos]) if dataset_attr.videos else None,
"_audios": regularize_medias(example[dataset_attr.audios]) if dataset_attr.audios else None,
}
return output
@@ -214,15 +190,15 @@ def convert_sharegpt(
logger.warning_rank0("Skipping this abnormal example.")
prompt, response = [], []
convert_images = partial(_convert_images, dataset_attr=dataset_attr, data_args=data_args)
convert_videos = partial(_convert_videos, dataset_attr=dataset_attr, data_args=data_args)
regularize_medias = partial(_regularize_medias, dataset_attr=dataset_attr, data_args=data_args)
output = {
"_prompt": prompt,
"_response": response,
"_system": system,
"_tools": example[dataset_attr.tools] if dataset_attr.tools else "",
"_images": convert_images(example[dataset_attr.images]) if dataset_attr.images else None,
"_videos": convert_videos(example[dataset_attr.videos]) if dataset_attr.videos else None,
"_images": regularize_medias(example[dataset_attr.images]) if dataset_attr.images else None,
"_videos": regularize_medias(example[dataset_attr.videos]) if dataset_attr.videos else None,
"_audios": regularize_medias(example[dataset_attr.audios]) if dataset_attr.audios else None,
}
return output
@@ -241,6 +217,7 @@ def align_dataset(
_tools: "...",
_images: [],
_videos: [],
_audios: [],
"""
if dataset_attr.formatting == "alpaca":
convert_func = partial(convert_alpaca, dataset_attr=dataset_attr, data_args=data_args)