fix Baichuan-13B

Former-commit-id: 6d9d826b3246349454c68f4d13b862da4de986e2
hiyouga
2023-07-13 23:08:45 +08:00
parent d57e0a7006
commit 316a02696f
10 changed files with 24 additions and 83 deletions

View File

@@ -6,8 +6,6 @@ from .common import (
     preprocess_data
 )

-from .data_collator import DynamicDataCollatorWithPadding
-
 from .peft_trainer import PeftTrainer, LogCallback

 from .seq2seq import ComputeMetrics, Seq2SeqPeftTrainer

View File

@@ -165,7 +165,7 @@ def load_pretrained(
     tokenizer = AutoTokenizer.from_pretrained(
         model_args.model_name_or_path,
         use_fast=model_args.use_fast_tokenizer,
-        padding_side="left",
+        padding_side=model_args.padding_side,
         **config_kwargs
     )
     if tokenizer.pad_token_id is None or tokenizer.pad_token_id == 64000: # 64000 for baichuan model (older version)
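With this change the padding side is no longer hard-coded to "left"; it follows the new model_args.padding_side option. A minimal sketch of what the option controls (the model name is illustrative, not taken from this diff):

```python
# Sketch: padding_side decides which end of a short sequence receives pad tokens.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2", padding_side="right")  # "gpt2" is illustrative
tok.pad_token = tok.eos_token  # gpt2 ships without a pad token
batch = tok(["a b c", "a"], padding=True, return_tensors="pt")
print(batch["attention_mask"])  # tensor([[1, 1, 1], [1, 0, 0]]) -- zeros on the right
```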

View File

@@ -47,6 +47,10 @@ class ModelArguments:
         default="main",
         metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}
     )
+    padding_side: Optional[Literal["left", "right"]] = field(
+        default="left",
+        metadata={"help": "The side on which the model should have padding applied."}
+    )
     quantization_bit: Optional[int] = field(
         default=None,
         metadata={"help": "The number of bits to quantize the model."}

View File

@@ -1,70 +0,0 @@
-import torch
-
-from typing import Dict, Optional, Sequence, Union
-
-from transformers import DataCollatorWithPadding, BatchEncoding
-from transformers.tokenization_utils import PreTrainedTokenizer
-
-from .other import IGNORE_INDEX
-
-
-class DynamicDataCollatorWithPadding(DataCollatorWithPadding):
-    r"""
-    Inherits DataCollatorWithPadding. It is capable of dynamically padding for batched data.
-    """
-
-    def __init__(
-            self,
-            tokenizer: PreTrainedTokenizer,
-            ignore_pad_token_for_loss: Optional[bool] = False
-    ):
-        super().__init__(tokenizer, padding=True)
-        self.label_pad_token_id = IGNORE_INDEX if ignore_pad_token_for_loss else tokenizer.pad_token_id
-
-    def get_attention_masks(self, input_ids: torch.Tensor, device: torch.device) -> torch.Tensor:
-        r"""
-        Generates attention masks for left-padded sequences.
-        """
-        batch_size, seq_length = input_ids.size()
-        attention_mask = torch.ones((batch_size, seq_length), device=device)
-
-        for i, seq in enumerate(input_ids):
-            attention_mask[i, :(seq != self.tokenizer.pad_token_id).nonzero()[0].item()] = 0 # padding
-
-        attention_mask = attention_mask.bool()
-        return attention_mask
-
-    def __call__(self, features: Sequence[Dict[str, Union[torch.Tensor, Sequence[int]]]]) -> BatchEncoding:
-        r"""
-        Pads batched data to the longest sequence in the batch.
-
-        We adopt left-padding in both training and evaluation.
-        """
-        if isinstance(features[0]["input_ids"], torch.Tensor):
-            input_ids = [feature["input_ids"].clone().detach().flip(0) for feature in features]
-        else:
-            input_ids = [torch.tensor(feature["input_ids"]).flip(0) for feature in features]
-
-        if "labels" in features[0]:
-            if isinstance(features[0]["labels"], torch.Tensor):
-                labels = [feature["labels"].clone().detach().flip(0) for feature in features]
-            else:
-                labels = [torch.tensor(feature["labels"]).flip(0) for feature in features]
-            input_ids = input_ids + labels # pad them to the same length
-
-        input_ids = torch.nn.utils.rnn.pad_sequence(
-            input_ids,
-            batch_first=True,
-            padding_value=self.tokenizer.pad_token_id
-        ).flip(-1)
-
-        batch = {}
-
-        if "labels" in features[0]:
-            input_ids, labels = input_ids.split(len(features), dim=0)
-            labels = torch.where(labels != self.tokenizer.pad_token_id, labels, self.label_pad_token_id)
-            batch["labels"] = labels
-
-        batch["input_ids"] = input_ids
-        batch["attention_mask"] = self.get_attention_masks(input_ids, device=input_ids.device)
-        return BatchEncoding(batch)
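The core trick of this deleted collator is worth recording: torch.nn.utils.rnn.pad_sequence only pads on the right, so the collator reversed each sequence, right-padded, and reversed back to obtain left padding. The idiom in isolation:

```python
# The deleted collator's left-padding idiom: flip, right-pad, flip back.
import torch

seqs = [torch.tensor([1, 2, 3]), torch.tensor([4, 5])]
padded = torch.nn.utils.rnn.pad_sequence(
    [s.flip(0) for s in seqs],  # reverse each sequence
    batch_first=True,
    padding_value=0,
).flip(-1)  # reverse back, so the padding ends up on the left
print(padded)  # tensor([[1, 2, 3], [0, 4, 5]])
```

Dropping this file in favor of the stock collators, together with the configurable padding_side above, removes the hand-rolled attention-mask logic entirely.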

View File

@@ -2,7 +2,7 @@ import torch
 import numpy as np

 from typing import Dict, Sequence, Tuple, Union

-from .data_collator import DynamicDataCollatorWithPadding
+from transformers import DataCollatorWithPadding

 from .peft_trainer import PeftTrainer
@@ -16,7 +16,7 @@ def compute_accuracy(eval_preds: Sequence[Union[np.ndarray, Tuple[np.ndarray]]])
     return {"accuracy": (preds[0] > preds[1]).sum() / len(preds[0])}


-class PairwiseDataCollatorWithPadding(DynamicDataCollatorWithPadding):
+class PairwiseDataCollatorWithPadding(DataCollatorWithPadding):
     r"""
     Data collator for pairwise data.
     """