mirror of
https://github.com/hiyouga/LlamaFactory.git
synced 2026-03-22 09:53:24 +08:00
[misc] upgrade format to py39 (#7256)
This commit is contained in:
@@ -13,7 +13,8 @@
|
||||
# limitations under the License.
|
||||
|
||||
from collections import defaultdict
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
|
||||
from collections.abc import Sequence
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
from ...extras import logging
|
||||
from ...extras.constants import IGNORE_INDEX
|
||||
@@ -30,15 +31,15 @@ logger = logging.get_logger(__name__)
|
||||
class FeedbackDatasetProcessor(DatasetProcessor):
|
||||
def _encode_data_example(
|
||||
self,
|
||||
prompt: Sequence[Dict[str, str]],
|
||||
response: Sequence[Dict[str, str]],
|
||||
kl_response: Sequence[Dict[str, str]],
|
||||
prompt: Sequence[dict[str, str]],
|
||||
response: Sequence[dict[str, str]],
|
||||
kl_response: Sequence[dict[str, str]],
|
||||
system: Optional[str],
|
||||
tools: Optional[str],
|
||||
images: Sequence["ImageInput"],
|
||||
videos: Sequence["VideoInput"],
|
||||
audios: Sequence["AudioInput"],
|
||||
) -> Tuple[List[int], List[int], List[int], List[int], bool]:
|
||||
) -> tuple[list[int], list[int], list[int], list[int], bool]:
|
||||
if response[0]["content"]: # desired example
|
||||
kto_tag = True
|
||||
messages = prompt + [response[0]]
|
||||
@@ -82,7 +83,7 @@ class FeedbackDatasetProcessor(DatasetProcessor):
|
||||
kl_labels = [IGNORE_INDEX] * kl_source_len + kl_response_ids
|
||||
return input_ids, labels, kl_input_ids, kl_labels, kto_tag
|
||||
|
||||
def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
|
||||
def preprocess_dataset(self, examples: dict[str, list[Any]]) -> dict[str, list[Any]]:
|
||||
# create unrelated input-output pairs for estimating the KL term by flipping the matched pairs
|
||||
kl_response = examples["_response"][::-1]
|
||||
model_inputs = defaultdict(list)
|
||||
@@ -121,7 +122,7 @@ class FeedbackDatasetProcessor(DatasetProcessor):
|
||||
|
||||
return model_inputs
|
||||
|
||||
def print_data_example(self, example: Dict[str, List[int]]) -> None:
|
||||
def print_data_example(self, example: dict[str, list[int]]) -> None:
|
||||
valid_labels = list(filter(lambda x: x != IGNORE_INDEX, example["labels"]))
|
||||
print("input_ids:\n{}".format(example["input_ids"]))
|
||||
print("inputs:\n{}".format(self.tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
|
||||
|
||||
Reference in New Issue
Block a user