-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path: adapter_inference_test.py
More file actions
70 lines (59 loc) · 2.06 KB
/
adapter_inference_test.py
File metadata and controls
70 lines (59 loc) · 2.06 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer
from peft import PeftModel
import torch, readline
from pathlib import Path
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import os
# Base checkpoint on the Hugging Face Hub and the local PEFT adapter to load.
BASE = "meta-llama/Meta-Llama-3-8B-Instruct"
ADAPTER = ""  # TODO: set to the local directory holding the trained PEFT adapter

# Fail fast with a clear message BEFORE downloading/loading the 8B base model:
# Path("") resolves to "." and PeftModel.from_pretrained would otherwise fail
# deep inside the library with a confusing error.
adapter_path = Path(ADAPTER)
if not ADAPTER or not adapter_path.exists():
    raise SystemExit(
        f"ADAPTER={ADAPTER!r} is empty or does not exist; "
        "set ADAPTER to the local directory containing the trained PEFT adapter."
    )

tokenizer = AutoTokenizer.from_pretrained(BASE, use_fast=True)
if tokenizer.pad_token_id is None:
    # Llama-3 ships without a pad token; reuse EOS so padding works.
    tokenizer.pad_token = tokenizer.eos_token

# Load the frozen base model in bf16, sharded automatically across devices.
model = AutoModelForCausalLM.from_pretrained(
    BASE, torch_dtype=torch.bfloat16, device_map="auto"
)
# Attach the fine-tuned adapter weights for inference only.
model = PeftModel.from_pretrained(
    model,
    str(adapter_path),  # local path, NOT a HF repo id
    is_trainable=False
)

# Streams decoded tokens to stdout as they are generated, hiding the prompt.
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

# System prompt prepended to every conversation; history holds prior turns.
SYSTEM = "You are a CURIOUS agent focussed on understanding and creating cultural awareness. "
history = []
def build_prompt(user):
    """Flatten the conversation into the tag-delimited training format.

    Prepends the system prompt, replays all prior turns from ``history``,
    appends the new ``user`` message, and ends with an open assistant tag
    so the model continues from there.
    """
    turns = [{"role": "system", "content": SYSTEM}]
    turns.extend(history)
    turns.append({"role": "user", "content": user})

    pieces = []
    for turn in turns:
        role = turn["role"]
        if role == "system":
            tag = "<|system|>"
        elif role == "user":
            tag = "<|user|>"
        else:
            tag = "<|assistant|>"
        pieces.append(f"{tag}\n{turn['content']}\n")
    pieces.append("<|assistant|>\n")
    return "".join(pieces)
# Decoding parameters shared by every turn of the chat loop.
gen_cfg = {
    "max_new_tokens": 512,      # hard cap on reply length
    "temperature": 0.2,         # near-greedy sampling
    "top_p": 0.9,               # nucleus sampling cutoff
    "repetition_penalty": 1.05, # mild discouragement of loops
    "do_sample": True,
}
# Interactive REPL: read a user line, stream the model's reply, and record
# both sides of the exchange so later prompts carry the full conversation.
while True:
    try:
        user = input("You: ")
        if not user:
            continue
        prompt = build_prompt(user)
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        print("Model:", end=" ", flush=True)
        out = model.generate(**inputs, streamer=streamer, **gen_cfg)
        # Slice off the prompt tokens so only the new completion is stored.
        reply = tokenizer.decode(out[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
        history.append({"role": "user", "content": user})
        history.append({"role": "assistant", "content": reply.strip()})
    except (KeyboardInterrupt, EOFError):
        # Ctrl-C anywhere in the turn, or Ctrl-D at the prompt (EOFError),
        # both exit cleanly instead of dumping a traceback.
        print("\nbye")
        break