diff --git a/fastchat/serve/model_worker.py b/fastchat/serve/model_worker.py index 683a78556..c21b8c36c 100644 --- a/fastchat/serve/model_worker.py +++ b/fastchat/serve/model_worker.py @@ -144,6 +144,7 @@ def generate_stream_gate(self, params): yield json.dumps(ret).encode() + b"\0" def generate_gate(self, params): + x = b"{}\0" for x in self.generate_stream_gate(params): pass return json.loads(x[:-1].decode()) @@ -171,7 +172,7 @@ def __process_embed_chunk(self, input_ids, attention_mask, **model_type_dict): mask = attention_mask.unsqueeze(-1).expand(data.size()).float() masked_embeddings = data * mask sum_embeddings = torch.sum(masked_embeddings, dim=1) - token_num = torch.sum(attention_mask).item() + token_num = attention_mask.sum(dim=1, keepdim=True) return sum_embeddings, token_num @@ -224,7 +225,7 @@ def get_embeddings(self, params): ): embedding = embedding / token_num normalized_embeddings = F.normalize(embedding, p=2, dim=1) - ret["token_num"] = token_num + ret["token_num"] = token_num.sum().item() else: all_embeddings = [] all_token_num = 0 @@ -273,7 +274,7 @@ def get_embeddings(self, params): embedding = torch.sum(all_embeddings_tensor, dim=0) / all_token_num normalized_embeddings = F.normalize(embedding, p=2, dim=1) - ret["token_num"] = all_token_num + ret["token_num"] = all_token_num.sum().item() if base64_encode == "base64": out_embeddings = self.__encode_base64(normalized_embeddings)