-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathprocessor.py
More file actions
153 lines (128 loc) · 6.84 KB
/
processor.py
File metadata and controls
153 lines (128 loc) · 6.84 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
import os
from typing import List
from pydantic import BaseModel, Field
from langchain_groq import ChatGroq
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_core.prompts import ChatPromptTemplate
from langchain_community.utilities import SQLDatabase
from langchain_core.output_parsers import StrOutputParser # change to StrOutputParser instead of pydantic, for raw string output in DB queries
from dotenv import load_dotenv
import sqlparse
# TODO: the next planned method is a Python code generator. Its output is meant
# to be rendered in app.py with st.code(result["python"], language="python"),
# the same way the generated SQL query output is formatted nicely in Streamlit.
# NOTE(review): app.py is not part of this file — confirm against the UI code.

# Populate environment variables from a .env file, presumably so that
# GROQ_API_KEY (read with os.getenv in TextAnalysisProcessor.__init__) can be
# kept out of the source code.
load_dotenv()

# Pydantic schemas for structured LLM output.
class SummaryOutput(BaseModel):
    """Schema for Use Case 1: Summarization"""

    # NOTE(review): the docstring and Field descriptions are likely forwarded to
    # the LLM as the tool/JSON schema by with_structured_output — treat them as
    # runtime text. The "3 to 6" count is only requested, never validated here.
    bullet_points: list[str] = Field(..., description="Exactly 3 to 6 concise summary points.")
    # One-sentence takeaway; summarize2 feeds these into the reduce step.
    key_conclusion: str = Field(..., description="A single sentence representing the main takeaway.")
class TopicOutput(BaseModel):
    """Schema for Use Case 2: Topic Extraction"""

    # NOTE(review): descriptions are likely part of the schema sent to the LLM.
    # The 3-7 count and the 1-3 words per topic are requested, not enforced.
    topics: list[str] = Field(..., description="3 to 7 distinct themes, each 1-3 words.")
    # Free-text rationale returned alongside the topics.
    explanation: str = Field(..., description="A brief explanation of how these topics were identified.")
class IntentOutput(BaseModel):
    """Schema for Use Case 3: Intent Classification"""

    # NOTE(review): descriptions are likely part of the schema sent to the LLM.
    # The label set is described only; any string is accepted by validation.
    intent: str = Field(..., description="The category: Technical issue, Billing question, Feature request, Complaint, General inquiry, Uncategorized, or Ambiguous.")
    # The 0-1 range is requested in the description but not enforced (no ge/le bounds).
    confidence_score: float = Field(..., description="Confidence score between 0 and 1.")
    # The model's explanation of why it picked the label.
    reasoning: str = Field(..., description="The logic used to classify the user message.")
class TextAnalysisProcessor:
    """LLM-backed text analysis plus natural-language querying of a SQL database.

    All LLM calls go through a single Groq chat model at temperature 0.
    Structured results are parsed into SummaryOutput, TopicOutput and
    IntentOutput. ``query_database`` returns a plain dict.
    """

    # Texts longer than this many characters are summarized map-reduce style.
    _SUMMARY_CHAR_THRESHOLD = 6000
    # Recursion limit for summarize2's map step.
    _SUMMARY_MAX_DEPTH = 2

    # Text-to-SQL prompt. It asks for read-only SQL in the database's own dialect.
    # _ensure_read_only still enforces the read-only rule, because LLM output
    # cannot be trusted.
    _SQL_PROMPT = (
        "You are a senior data analyst and SQL expert.\n"
        "Given the database schema below, write a correct {dialect} SQL query\n"
        "that answers the user's question.\n"
        "Rules:\n"
        "- Use only the tables and columns in the schema\n"
        "- Write a single read-only SELECT statement (never modify data or schema)\n"
        "- Do NOT explain anything\n"
        "- Return ONLY the SQL query\n"
        "Schema:\n"
        "{schema}\n"
        "Question:\n"
        "{question}\n"
    )

    def __init__(
        self,
        db_uri: str = "sqlite:///student_grades.db",
        model_name: str = "llama-3.3-70b-versatile",
    ):
        """Create the LLM client, the database handle and the text splitter.

        Args:
            db_uri: SQLAlchemy URI passed to ``SQLDatabase.from_uri``. The
                default ``sqlite:///student_grades.db`` is a path relative to
                the current working directory.
            model_name: Groq model identifier used for every LLM call.
        """
        self.llm = ChatGroq(
            temperature=0,
            model_name=model_name,
            groq_api_key=os.getenv("GROQ_API_KEY"),
        )
        # NOTE(review): for defence in depth, consider opening the DB read-only,
        # e.g. "sqlite:///file:student_grades.db?mode=ro&uri=true".
        self.db = SQLDatabase.from_uri(db_uri)
        # Chunking used by summarize2's map step.
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=500,
            chunk_overlap=50,
        )

    def _invoke_structured(self, schema: type[BaseModel], template: str, inputs: dict) -> BaseModel:
        """Fill *template* with *inputs*, call the LLM, and parse the reply into *schema*."""
        chain = ChatPromptTemplate.from_template(template) | self.llm.with_structured_output(schema)
        return chain.invoke(inputs)

    def summarize(self, text: str) -> SummaryOutput:
        """Use Case 1: Summarization.

        Summarize *text* into 3-6 bullet points plus a one-sentence key conclusion.
        """
        return self._invoke_structured(
            SummaryOutput,
            "Summarize the following text into 3-6 bullet points and provide a key conclusion. "
            "Preserve main ideas and key facts.\n\nText: {text}",
            {"text": text},
        )

    def synthesize_summaries(self, summaries: List[str]) -> SummaryOutput:
        """Merge several section summaries into one coherent summary (REDUCE step).

        Args:
            summaries: Per-section summary strings. They are joined with newlines
                into a single prompt.
        """
        return self._invoke_structured(
            SummaryOutput,
            "These are summaries of different document sections. "
            "Synthesize them into one coherent summary with 3-6 bullet points and a single key conclusion. "
            "Ensure the final summary captures the essence of all sections.\n\nSection Summaries:\n{text}",
            {"text": "\n".join(summaries)},
        )

    def summarize2(self, text: str, depth: int = 0) -> SummaryOutput:
        """Summarize text of any length with a map-reduce strategy.

        Texts of at most ``_SUMMARY_CHAR_THRESHOLD`` characters go straight to
        ``summarize``, as does any text once ``depth`` reaches
        ``_SUMMARY_MAX_DEPTH``. Longer texts take three steps:

        1. Split the text with ``self.text_splitter``.
        2. Summarize each chunk recursively (MAP) and keep only its ``key_conclusion``.
        3. Merge those conclusions with ``synthesize_summaries`` (REDUCE).

        Args:
            text: The text to summarize.
            depth: Current recursion depth. Callers should leave the default.
        """
        if len(text) > self._SUMMARY_CHAR_THRESHOLD and depth < self._SUMMARY_MAX_DEPTH:
            # NOTE(review): with chunk_size=500 every chunk is expected to be far
            # below the threshold, so each recursive call should hit the base case.
            partial_conclusions = [
                self.summarize2(chunk, depth + 1).key_conclusion
                for chunk in self.text_splitter.split_text(text)
            ]
            return self.synthesize_summaries(partial_conclusions)
        return self.summarize(text)

    def extract_topics(self, text: str) -> TopicOutput:
        """Use Case 2: Topic Extraction.

        Ask the LLM for 3-7 distinct topics of 1-3 words each.
        """
        return self._invoke_structured(
            TopicOutput,
            "Identify 3-7 key topics from the text. Each topic must be 1-3 words. "
            "Ensure topics are meaningful and distinct.\n\nText: {text}",
            {"text": text},
        )

    def classify_intent(self, message: str) -> IntentOutput:
        """Use Case 3: Intent Classification.

        Classify *message* into one of the fixed support labels, or into
        'Ambiguous' / 'Uncategorized'.
        """
        return self._invoke_structured(
            IntentOutput,
            "Classify the intent of the following user message: '{message}'.\n"
            "Use labels: Technical issue, Billing question, Feature request, Complaint, General inquiry. "
            "Label as 'Ambiguous' if unclear or 'Uncategorized' if it doesn't fit.",
            {"message": message},
        )

    @staticmethod
    def _strip_markdown_fences(raw: str) -> str:
        """Return the SQL inside the first Markdown code fence of *raw*, or *raw* stripped.

        This tolerates prose before the fence and a language tag in any case
        (``sql``, ``SQL``, ``sqlite``).
        """
        text = raw.strip()
        if "```" in text:
            # Element 1 of the split is the body of the first fenced block.
            text = text.split("```")[1].lstrip()
            lowered = text.lower()
            for tag in ("sqlite", "sql"):
                if lowered.startswith(tag):
                    text = text[len(tag):]
                    break
        return text.strip()

    @staticmethod
    def _ensure_read_only(sql: str) -> None:
        """Raise ``ValueError`` unless *sql* is exactly one SELECT statement.

        The SQL is produced by an LLM from untrusted user input, so it cannot
        be trusted to leave the database untouched. Statements that contain
        only whitespace or comments are ignored. ``sqlparse`` also reports
        ``WITH ... SELECT`` (CTE) queries as SELECT.
        """
        statements = [
            stmt for stmt in sqlparse.parse(sql)
            if stmt.token_first(skip_cm=True) is not None
        ]
        if len(statements) != 1 or statements[0].get_type() != "SELECT":
            raise ValueError(
                "Only a single read-only SELECT query may be executed; "
                "refusing to run the generated SQL."
            )

    def query_database(self, query: str) -> dict:
        """Answer a natural-language question by generating and executing SQL on ``self.db``.

        Args:
            query: The user's question. Treated as untrusted input.

        Returns:
            A dict with two keys:

            - ``'sql'``: the cleaned SQL the LLM produced, or ``""`` if
              generation itself failed.
            - ``'results'``: the rows returned by ``self.db._execute`` on
              success. On any failure it is an
              ``"Error executing database query: ..."`` string. Failures include
              generated SQL that is not a single read-only SELECT statement.
        """
        # Initialized up front so the error path can report whatever SQL was produced.
        generated_sql = ""
        try:
            sql_chain = ChatPromptTemplate.from_template(self._SQL_PROMPT) | self.llm | StrOutputParser()
            raw_sql = sql_chain.invoke(
                {
                    "dialect": self.db.dialect,
                    "schema": self.db.get_table_info(),
                    "question": query,
                }
            )
            generated_sql = self._strip_markdown_fences(raw_sql)
            self._ensure_read_only(generated_sql)
            # NOTE(review): _execute is a private SQLDatabase API. It is used instead
            # of the public run(), which returns a string. In recent
            # langchain_community versions it runs inside engine.begin(), so writes
            # would be committed — hence the read-only guard above. Recheck on upgrades.
            results = self.db._execute(generated_sql)
            return {"sql": generated_sql, "results": results}
        except Exception as e:  # UI boundary: report every failure in-band, never raise
            return {"sql": generated_sql, "results": f"Error executing database query: {str(e)}"}