-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdeploy_classifier.py
More file actions
40 lines (32 loc) · 1.25 KB
/
deploy_classifier.py
File metadata and controls
40 lines (32 loc) · 1.25 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
from transformers import BertTokenizer, BertForSequenceClassification
import torch
# Load the saved BERT tokenizer and sequence-classification model once, at
# import time, and put the model straight into inference mode (Module.eval()
# returns the module itself, so it can be chained onto the load).
# NOTE(review): the path is relative, so it is resolved against the current
# working directory, not this file's location — confirm that is intended.
model_path = './bert_news_classifier'
tokenizer = BertTokenizer.from_pretrained(model_path)
model = BertForSequenceClassification.from_pretrained(model_path).eval()

# Display name for each classifier output: entry i labels logit/probability i.
# Assumes the checkpoint was trained with exactly this label order — TODO
# confirm against model.config.id2label.
class_names = ['World', 'Sports', 'Business', 'Sci/Tech']
def classify_news(text, *, max_length=128):
    """Classify a single news headline into one of ``class_names``.

    Args:
        text: The headline to classify (a single string).
        max_length: Maximum number of tokens passed to the model; longer
            input is truncated. Defaults to 128, the previously hard-coded
            limit.

    Returns:
        A dict with:
            'category': name of the highest-probability class,
            'confidence': that class's softmax probability as a float,
            'all_probabilities': every class name mapped to its probability.
    """
    # Truncate only. The old padding='max_length' forced every headline to
    # max_length tokens; for a single sequence those padding positions are
    # masked out anyway, so they only cost compute.
    inputs = tokenizer(text, truncation=True, max_length=max_length,
                       return_tensors='pt')
    # Keep the input tensors on the model's device so this still works if the
    # model has been moved (e.g. model.to('cuda')).
    inputs = inputs.to(model.device)
    with torch.no_grad():
        outputs = model(**inputs)
    # The batch holds exactly one sequence; take its row of probabilities.
    probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)[0]
    predicted_class = int(torch.argmax(probabilities))
    return {
        'category': class_names[predicted_class],
        'confidence': probabilities[predicted_class].item(),
        'all_probabilities': {
            name: prob.item()
            for name, prob in zip(class_names, probabilities)
        },
    }
# Demo: classify one sample headline when this file is run as a script.
if __name__ == '__main__':
    headline = "NASA launches new Mars rover mission"
    outcome = classify_news(headline)
    category, confidence = outcome['category'], outcome['confidence']
    print(f"Text: {headline}")
    print(f"Category: {category} ({confidence:.2%})")