LlamaFactory/src/cli_demo.py
# coding=utf-8
# Implements a command-line chat demo for fine-tuned models
# (responses are generated in full, not streamed token by token).
# Usage: python cli_demo.py --checkpoint_dir path_to_checkpoint
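#
# Example invocation (a sketch; the exact flags depend on what
# prepare_infer_args exposes — --model_name_or_path and --prompt_template
# are assumed here from the attributes used below):
#   python cli_demo.py --model_name_or_path path_to_model \
#       --checkpoint_dir path_to_checkpoint --prompt_template ziya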

import torch
from utils import (
    load_pretrained,
    prepare_infer_args,
    get_logits_processor
)


def main():

    model_args, data_args, finetuning_args = prepare_infer_args()
    model_name = "BLOOM" if "bloom" in model_args.model_name_or_path else "LLaMA"
    model, tokenizer = load_pretrained(model_args, finetuning_args)

    # Shard the model across GPUs when more than one is available,
    # otherwise move it to the single GPU.
    if torch.cuda.device_count() > 1:
        from accelerate import dispatch_model, infer_auto_device_map
        device_map = infer_auto_device_map(model)
        model = dispatch_model(model, device_map)
    else:
        model = model.cuda()

    model.eval()

    def format_example_alpaca(query, history):
        prompt = "Below is an instruction that describes a task. "
        prompt += "Write a response that appropriately completes the request.\n"
        prompt += "Instruction:\n"
        for old_query, response in history:
            prompt += "Human: {}\nAssistant: {}\n".format(old_query, response)
        prompt += "Human: {}\nAssistant:".format(query)
        return prompt

    def format_example_ziya(query, history):
        prompt = ""
        for old_query, response in history:
            prompt += "<human>: {}\n<bot>: {}\n".format(old_query, response)
        prompt += "<human>: {}\n<bot>:".format(query)
        return prompt

    format_example = format_example_alpaca if data_args.prompt_template == "alpaca" else format_example_ziya
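    # For illustration: with history = [("Hi", "Hello!")] and query = "How are you?",
    # format_example_ziya returns
    #   <human>: Hi
    #   <bot>: Hello!
    #   <human>: How are you?
    #   <bot>:
    # while format_example_alpaca wraps the same turns in its instruction preamble.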

    def predict(query, history: list):
        # Build the prompt from the chat history and tokenize it.
        input_ids = tokenizer([format_example(query, history)], return_tensors="pt")["input_ids"]
        input_ids = input_ids.to(model.device)
        gen_kwargs = {
            "do_sample": True,
            "top_p": 0.7,
            "temperature": 0.95,
            "num_beams": 1,
            "max_new_tokens": 256,
            "repetition_penalty": 1.5,
            "logits_processor": get_logits_processor()
        }
        with torch.no_grad():
            generation_output = model.generate(input_ids=input_ids, **gen_kwargs)
        # Strip the prompt tokens so only the newly generated reply is decoded.
        outputs = generation_output.tolist()[0][len(input_ids[0]):]
        response = tokenizer.decode(outputs, skip_special_tokens=True)
        history = history + [(query, response)]
        return response, history

    history = []
    print("Welcome to the {} model. Enter text to chat, \"clear\" to clear the chat history, \"stop\" to exit.".format(model_name))

    while True:
        try:
            query = input("\nInput: ")
        except UnicodeDecodeError:
            print("Detected decoding error at the inputs, please set the terminal encoding to utf-8.")
            continue
        except Exception:
            raise

        if query.strip() == "stop":
            break

        if query.strip() == "clear":
            history = []
            print("History has been removed.")
            continue

        response, history = predict(query, history)
        print("{}:".format(model_name), response)


if __name__ == "__main__":
    main()
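
# Example session (illustrative only; actual output depends on the loaded checkpoint):
#   Input: Hello!
#   LLaMA: Hi! How can I help you today?
#   Input: stop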