
Commit 7ebb0b2

Fix logprobs when multiple tokens are returned at once.
1 parent d03752e commit 7ebb0b2

4 files changed

Lines changed: 80 additions & 38 deletions


backends/exllamav2/model.py

Lines changed: 58 additions & 23 deletions
@@ -767,21 +767,25 @@ def get_special_tokens(
         }
 
     def get_logprobs(self, token_ids: torch.Tensor, token_probs: torch.Tensor):
-        top_tokens = [
-            self.tokenizer.extended_id_to_piece.get(
-                index, self.tokenizer.id_to_piece[index]
-            )
-            for index in token_ids.flatten().tolist()
-        ]
+        logprobs = []
+        for token_idx in range(token_ids.shape[1]):
+            top_tokens = [
+                self.tokenizer.extended_id_to_piece.get(
+                    index, self.tokenizer.id_to_piece[index]
+                )
+                for index in token_ids[0, token_idx].tolist()
+            ]
 
-        top_values = torch.log(token_probs).flatten().tolist()
+            top_values = torch.log(token_probs[0, token_idx]).tolist()
 
-        # Cannot return -inf in JSON
-        cleaned_values = [
-            -1000 if value == float("-inf") else value for value in top_values
-        ]
+            # Cannot return -inf in JSON
+            cleaned_values = [
+                -1000 if value == float("-inf") else value for value in top_values
+            ]
 
-        return dict(zip_longest(top_tokens, cleaned_values))
+            logprobs.append(dict(zip_longest(top_tokens, cleaned_values)))
+
+        return logprobs
 
     async def generate(self, prompt: str, **kwargs):
         """Generate a response to a prompt"""

@@ -793,8 +797,9 @@ async def generate(self, prompt: str, **kwargs):
             "text": "",
             "prompt_tokens": 0,
             "generation_tokens": 0,
+            "tokens": [],
             "offset": [],
-            "token_probs": {},
+            "token_probs": [],
             "logprobs": [],
         }

@@ -811,13 +816,14 @@ async def generate(self, prompt: str, **kwargs):
         if len(generations) > 0:
             for generation in generations:
                 joined_generation["text"] += unwrap(generation.get("text"), "")
-                joined_generation["offset"].append(unwrap(generation.get("offset"), -1))
-                joined_generation["token_probs"].update(
-                    unwrap(generation.get("token_probs"), {})
+                joined_generation["tokens"].extend(unwrap(generation.get("tokens"), []))
+                joined_generation["offset"].extend(unwrap(generation.get("offset"), []))
+                joined_generation["token_probs"].extend(
+                    unwrap(generation.get("token_probs"), [])
                 )
 
                 # Include empty logprob dicts for index preservation
-                joined_generation["logprobs"].append(
+                joined_generation["logprobs"].extend(
                     unwrap(generation.get("logprobs"), {})
                 )

@@ -1145,7 +1151,6 @@ async def generate_gen(
                     "text": chunk,
                     "prompt_tokens": context_len,
                     "generated_tokens": generated_tokens,
-                    "offset": len(full_response),
                 }
 
                 if request_logprobs > 0:

@@ -1164,11 +1169,41 @@ async def generate_gen(
                     logprobs = self.get_logprobs(top_tokens, top_probs)
                     generation["logprobs"] = logprobs
 
-                    # The first logprob is the selected token prob
-                    generation["token_probs"] = {
-                        token: logprobs[token]
-                        for token in list(logprobs.keys())[:1]
-                    }
+                    token_ids = unwrap(
+                        result.get("token_ids"),
+                        torch.empty(0),
+                    )
+
+                    token_probs = unwrap(
+                        result.get("token_probs"),
+                        torch.empty(0),
+                    )
+
+                    if token_ids.numel() > 0 and token_probs.numel() > 0:
+                        token_ids = token_ids.flatten().tolist()
+                        token_probs = token_probs.flatten().tolist()
+
+                        tokens = [
+                            self.tokenizer.extended_id_to_piece.get(
+                                index, self.tokenizer.id_to_piece[index]
+                            )
+                            for index in token_ids
+                        ]
+
+                        generation["tokens"] = tokens
+                        generation["token_probs"] = [
+                            math.log(prob) for prob in token_probs
+                        ]
+
+                        # Calculate the offset of each token in the output,
+                        # working backwards from the end.
+                        offsets = []
+                        token_offset = 0
+                        for token in tokens:
+                            token_offset += len(token)
+                            offsets.append(len(full_response) - token_offset)
+                        offsets.reverse()
+                        generation["offset"] = offsets
 
                     yield generation

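The key behavioral change in backends/exllamav2/model.py is that get_logprobs now returns a list with one top-logprob dict per generated token instead of a single dict for the whole chunk, and the streamed generation carries parallel tokens, token_probs, and offset lists. A minimal standalone sketch of that per-token shape follows; the token strings and probabilities are invented for illustration and do not come from the generator.

import math
from itertools import zip_longest

# Hypothetical top candidates for two generated positions (illustration only).
top_tokens_per_position = [["Hello", "Hi"], [" world", " there"]]
top_probs_per_position = [[0.91, 0.05], [0.85, 0.0]]

logprobs = []
for top_tokens, top_probs in zip(top_tokens_per_position, top_probs_per_position):
    top_values = [math.log(p) if p > 0 else float("-inf") for p in top_probs]

    # Cannot return -inf in JSON, so clamp it the same way the diff does.
    cleaned_values = [-1000 if v == float("-inf") else v for v in top_values]

    logprobs.append(dict(zip_longest(top_tokens, cleaned_values)))

# One dict per generated token, e.g.:
# [{'Hello': -0.094, 'Hi': -2.996}, {' world': -0.163, ' there': -1000}]
print(logprobs)
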
endpoints/OAI/types/chat_completion.py

Lines changed: 4 additions & 1 deletion
@@ -9,11 +9,14 @@
 class ChatCompletionLogprob(BaseModel):
     token: str
     logprob: float
+
+
+class ChatCompletionLogprobChoice(ChatCompletionLogprob):
     top_logprobs: Optional[List["ChatCompletionLogprob"]] = None
 
 
 class ChatCompletionLogprobs(BaseModel):
-    content: List[ChatCompletionLogprob] = Field(default_factory=list)
+    content: List[ChatCompletionLogprobChoice] = Field(default_factory=list)
 
 
 class ChatCompletionMessage(BaseModel):

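To make the schema change concrete, here is a self-contained sketch that redefines the three models exactly as they appear in the diff and serializes one entry. The sample token and logprob values are invented, and model_dump_json assumes pydantic v2 (v1 would use .json() instead).

from typing import List, Optional

from pydantic import BaseModel, Field


class ChatCompletionLogprob(BaseModel):
    token: str
    logprob: float


class ChatCompletionLogprobChoice(ChatCompletionLogprob):
    top_logprobs: Optional[List["ChatCompletionLogprob"]] = None


class ChatCompletionLogprobs(BaseModel):
    content: List[ChatCompletionLogprobChoice] = Field(default_factory=list)


# One entry per chosen token, each carrying its own list of top candidates.
example = ChatCompletionLogprobs(
    content=[
        ChatCompletionLogprobChoice(
            token="Hello",
            logprob=-0.09,
            top_logprobs=[
                ChatCompletionLogprob(token="Hello", logprob=-0.09),
                ChatCompletionLogprob(token="Hi", logprob=-3.0),
            ],
        )
    ]
)

print(example.model_dump_json(indent=2))
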
endpoints/OAI/utils/chat_completion.py

Lines changed: 14 additions & 11 deletions
@@ -22,6 +22,7 @@
 from endpoints.OAI.types.chat_completion import (
     ChatCompletionLogprobs,
     ChatCompletionLogprob,
+    ChatCompletionLogprobChoice,
     ChatCompletionMessage,
     ChatCompletionRequest,
     ChatCompletionRespChoice,

@@ -46,22 +47,24 @@ def _create_response(generations: List[dict], model_name: Optional[str]):
 
         logprob_response = None
 
-        token_probs = unwrap(generation.get("token_probs"), {})
+        tokens = unwrap(generation.get("tokens"), [])
+        token_probs = unwrap(generation.get("token_probs"), [])
+        logprobs = unwrap(generation.get("logprobs"), [])
         if token_probs:
-            logprobs = unwrap(generation.get("logprobs"), [])
-
             collected_token_probs = []
-            for index, token in enumerate(token_probs.keys()):
-                top_logprobs = [
-                    ChatCompletionLogprob(token=token, logprob=logprob)
-                    for token, logprob in logprobs[index].items()
+            for output_token, token_logprob, top_logprobs in zip(
+                tokens, token_probs, logprobs, strict=True
+            ):
+                completion_logprobs = [
+                    ChatCompletionLogprob(token=token, logprob=token_logprob)
+                    for token, token_logprob in top_logprobs.items()
                 ]
 
                 collected_token_probs.append(
-                    ChatCompletionLogprob(
-                        token=token,
-                        logprob=token_probs[token],
-                        top_logprobs=top_logprobs,
+                    ChatCompletionLogprobChoice(
+                        token=output_token,
+                        logprob=token_logprob,
+                        top_logprobs=completion_logprobs,
                     )
                 )

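The reworked loop walks three parallel, equal-length lists (tokens, token_probs, and logprobs) instead of indexing into a single dict keyed by token text. A plain-dict sketch of that zip pattern follows; the sample data is invented, and strict=True requires Python 3.10+ and raises if the lists ever fall out of sync.

# Parallel lists as produced by the model backend; the values are invented here.
tokens = ["Hello", " world"]
token_probs = [-0.09, -0.16]          # logprob of each chosen token
logprobs = [                          # top candidates per position
    {"Hello": -0.09, "Hi": -3.0},
    {" world": -0.16, " there": -2.1},
]

collected = []
for output_token, token_logprob, top_logprobs in zip(
    tokens, token_probs, logprobs, strict=True
):
    collected.append(
        {
            "token": output_token,
            "logprob": token_logprob,
            "top_logprobs": [
                {"token": token, "logprob": lp} for token, lp in top_logprobs.items()
            ],
        }
    )

print(collected)
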
endpoints/OAI/utils/completion.py

Lines changed: 4 additions & 3 deletions
@@ -35,15 +35,16 @@ def _create_response(generations: Union[dict, List[dict]], model_name: str = "")
     for index, generation in enumerate(generations):
         logprob_response = None
 
-        token_probs = unwrap(generation.get("token_probs"), {})
+        tokens = unwrap(generation.get("tokens"), [])
+        token_probs = unwrap(generation.get("token_probs"), [])
         if token_probs:
             logprobs = unwrap(generation.get("logprobs"), [])
             offset = unwrap(generation.get("offset"), [])
 
             logprob_response = CompletionLogProbs(
                 text_offset=offset if isinstance(offset, list) else [offset],
-                token_logprobs=token_probs.values(),
-                tokens=token_probs.keys(),
+                token_logprobs=token_probs,
+                tokens=tokens,
                 top_logprobs=logprobs if isinstance(logprobs, list) else [logprobs],
             )

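For the legacy completions endpoint the same parallel lists map straight onto the OpenAI-style logprobs object. A rough sketch of the resulting payload shape, using the field names from the CompletionLogProbs call above with invented values:

# Hypothetical generation dict; the real one is produced by the model backend.
generation = {
    "text": "Hello world",
    "tokens": ["Hello", " world"],
    "token_probs": [-0.09, -0.16],
    "offset": [0, 5],
    "logprobs": [
        {"Hello": -0.09, "Hi": -3.0},
        {" world": -0.16, " there": -2.1},
    ],
}

logprob_payload = {
    "text_offset": generation["offset"],        # start index of each token in text
    "tokens": generation["tokens"],             # decoded token pieces
    "token_logprobs": generation["token_probs"],
    "top_logprobs": generation["logprobs"],
}

print(logprob_payload)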