mirror of
https://github.com/hiyouga/LlamaFactory.git
synced 2026-03-23 02:33:24 +08:00
[model] support audio (#6701)
* support qwen2_audio
* improve code
* lint
* fix
* fix
* fix

---------

Co-authored-by: hiyouga <hiyouga@buaa.edu.cn>
Former-commit-id: 5eacb5629e4d7733cd992a63747a1335f2c6a929
This commit is contained in:
@@ -25,57 +25,33 @@ if TYPE_CHECKING:
|
||||
from transformers import Seq2SeqTrainingArguments
|
||||
|
||||
from ..hparams import DataArguments
|
||||
from .mm_plugin import ImageInput, VideoInput
|
||||
from .parser import DatasetAttr
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def _convert_images(
|
||||
images: Union["ImageInput", Sequence["ImageInput"]],
|
||||
def _regularize_medias(
|
||||
inputs: Union[Any, Sequence[Any]],
|
||||
dataset_attr: "DatasetAttr",
|
||||
data_args: "DataArguments",
|
||||
) -> Optional[List["ImageInput"]]:
|
||||
) -> Optional[List[Any]]:
|
||||
r"""
|
||||
Optionally concatenates image path to dataset dir when loading from local disk.
|
||||
Optionally concatenates media path to media dir when loading from local disk.
|
||||
"""
|
||||
if not isinstance(images, list):
|
||||
images = [images]
|
||||
elif len(images) == 0:
|
||||
if not isinstance(inputs, list):
|
||||
inputs = [inputs]
|
||||
elif len(inputs) == 0:
|
||||
return None
|
||||
else:
|
||||
images = images[:]
|
||||
inputs = inputs[:]
|
||||
|
||||
if dataset_attr.load_from in ["script", "file"]:
|
||||
for i in range(len(images)):
|
||||
if isinstance(images[i], str) and os.path.isfile(os.path.join(data_args.image_dir, images[i])):
|
||||
images[i] = os.path.join(data_args.image_dir, images[i])
|
||||
for i in range(len(inputs)):
|
||||
if isinstance(inputs[i], str) and os.path.isfile(os.path.join(data_args.media_dir, inputs[i])):
|
||||
inputs[i] = os.path.join(data_args.media_dir, inputs[i])
|
||||
|
||||
return images
|
||||
|
||||
|
||||
def _convert_videos(
|
||||
videos: Union["VideoInput", Sequence["VideoInput"]],
|
||||
dataset_attr: "DatasetAttr",
|
||||
data_args: "DataArguments",
|
||||
) -> Optional[List["VideoInput"]]:
|
||||
r"""
|
||||
Optionally concatenates video path to dataset dir when loading from local disk.
|
||||
"""
|
||||
if not isinstance(videos, list):
|
||||
videos = [videos]
|
||||
elif len(videos) == 0:
|
||||
return None
|
||||
else:
|
||||
videos = videos[:]
|
||||
|
||||
if dataset_attr.load_from in ["script", "file"]:
|
||||
for i in range(len(videos)):
|
||||
if isinstance(videos[i], str) and os.path.isfile(os.path.join(data_args.image_dir, videos[i])):
|
||||
videos[i] = os.path.join(data_args.image_dir, videos[i])
|
||||
|
||||
return videos
|
||||
return inputs
|
||||
|
||||
|
||||
def convert_alpaca(
|
||||
@@ -121,15 +97,15 @@ def convert_alpaca(
|
||||
else: # unsupervised
|
||||
response = []
|
||||
|
||||
convert_images = partial(_convert_images, dataset_attr=dataset_attr, data_args=data_args)
|
||||
convert_videos = partial(_convert_videos, dataset_attr=dataset_attr, data_args=data_args)
|
||||
regularize_medias = partial(_regularize_medias, dataset_attr=dataset_attr, data_args=data_args)
|
||||
output = {
|
||||
"_prompt": prompt,
|
||||
"_response": response,
|
||||
"_system": example[dataset_attr.system] if dataset_attr.system else "",
|
||||
"_tools": example[dataset_attr.tools] if dataset_attr.tools else "",
|
||||
"_images": convert_images(example[dataset_attr.images]) if dataset_attr.images else None,
|
||||
"_videos": convert_videos(example[dataset_attr.videos]) if dataset_attr.videos else None,
|
||||
"_images": regularize_medias(example[dataset_attr.images]) if dataset_attr.images else None,
|
||||
"_videos": regularize_medias(example[dataset_attr.videos]) if dataset_attr.videos else None,
|
||||
"_audios": regularize_medias(example[dataset_attr.audios]) if dataset_attr.audios else None,
|
||||
}
|
||||
return output
|
||||
|
||||
@@ -214,15 +190,15 @@ def convert_sharegpt(
|
||||
logger.warning_rank0("Skipping this abnormal example.")
|
||||
prompt, response = [], []
|
||||
|
||||
convert_images = partial(_convert_images, dataset_attr=dataset_attr, data_args=data_args)
|
||||
convert_videos = partial(_convert_videos, dataset_attr=dataset_attr, data_args=data_args)
|
||||
regularize_medias = partial(_regularize_medias, dataset_attr=dataset_attr, data_args=data_args)
|
||||
output = {
|
||||
"_prompt": prompt,
|
||||
"_response": response,
|
||||
"_system": system,
|
||||
"_tools": example[dataset_attr.tools] if dataset_attr.tools else "",
|
||||
"_images": convert_images(example[dataset_attr.images]) if dataset_attr.images else None,
|
||||
"_videos": convert_videos(example[dataset_attr.videos]) if dataset_attr.videos else None,
|
||||
"_images": regularize_medias(example[dataset_attr.images]) if dataset_attr.images else None,
|
||||
"_videos": regularize_medias(example[dataset_attr.videos]) if dataset_attr.videos else None,
|
||||
"_audios": regularize_medias(example[dataset_attr.audios]) if dataset_attr.audios else None,
|
||||
}
|
||||
return output
|
||||
|
||||
@@ -241,6 +217,7 @@ def align_dataset(
|
||||
_tools: "...",
|
||||
_images: [],
|
||||
_videos: [],
|
||||
_audios: [],
|
||||
"""
|
||||
if dataset_attr.formatting == "alpaca":
|
||||
convert_func = partial(convert_alpaca, dataset_attr=dataset_attr, data_args=data_args)
|
||||
|
||||
Reference in New Issue
Block a user