-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathapp.py
More file actions
141 lines (116 loc) · 4.77 KB
/
app.py
File metadata and controls
141 lines (116 loc) · 4.77 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
import asyncio
import json
import os
import tempfile
import time

import PyPDF2
import streamlit as st
from docx import Document  # Correct import for python-docx
from dotenv import load_dotenv
from groq import AsyncGroq
# Load environment variables
load_dotenv()
# --- Setup ---
def extract_text_from_pdf(pdf_path):
    """Return the concatenated text of every page of a PDF.

    Args:
        pdf_path: Filesystem path to the PDF, or an already-open binary
            file-like object (anything exposing ``read``), e.g. an
            in-memory upload.

    Returns:
        The extracted text of all pages, joined without a separator.
    """
    def _read_pages(stream):
        reader = PyPDF2.PdfReader(stream)
        # Defensive: treat a missing extraction result as empty text so the
        # join cannot fail on a non-string value.
        return ''.join(page.extract_text() or '' for page in reader.pages)

    # File-like objects are handed to the reader directly; paths are opened
    # here so the handle is closed as soon as extraction finishes.
    if hasattr(pdf_path, 'read'):
        return _read_pages(pdf_path)
    with open(pdf_path, 'rb') as file:
        return _read_pages(file)
def extract_text_from_docx(docx_path):
    """Return the text of a DOCX document, one paragraph per line.

    Args:
        docx_path: Location of the document, passed straight to
            python-docx's ``Document``.

    Returns:
        The text of every paragraph in ``doc.paragraphs``, joined with
        newlines.
    """
    document = Document(docx_path)  # Use Document from python-docx
    return '\n'.join(paragraph.text for paragraph in document.paragraphs)
def load_knowledge_cache(file_path, max_length=6000):
    """Extract text from a PDF or DOCX file, truncated to ``max_length``.

    Args:
        file_path: Path whose text after the last ``.`` selects the extractor
            (case-insensitive ``pdf`` or ``docx``).
        max_length: Maximum number of characters of extracted text to keep.

    Returns:
        The truncated text, or ``None`` when the extension is unsupported or
        extraction raises. Both failures are reported through ``st.error``.
    """
    suffix = file_path.rsplit('.', 1)[-1].lower()
    try:
        # Pick the extractor by extension; anything else is rejected below.
        extractors = {
            'pdf': extract_text_from_pdf,
            'docx': extract_text_from_docx,
        }
        extractor = extractors.get(suffix)
        if extractor is None:
            st.error("Unsupported file type. Please upload a PDF or DOCX file.")
            return None
        text = extractor(file_path)
        return text[:max_length]
    except Exception as e:
        st.error(f"An error occurred while processing the file: {str(e)}")
        return None
# --- Constants ---
# Models offered in the sidebar. Each key is the label shown in the selectbox
# and each value is the model id sent to the API; here they are identical.
MODEL_OPTIONS = {
    model_id: model_id
    for model_id in ("qwen-qwq-32b", "llama-3.1-8b-instant", "gemma2-9b-it")
}
# --- Initialize State ---
# Create the chat transcript on the first run only; later reruns keep the
# existing list of {"role", "content"} messages.
st.session_state.setdefault("chat_history", [])

st.title("CAG - Cache Augmented Generation")
# --- Sidebar ---
# Collects the API key, model choice and source document. The script halts
# (st.stop) until both a key and a file are available.
# NOTE(review): assumes every widget below belongs inside the sidebar —
# confirm against the intended layout.
with st.sidebar:
    st.header("CAG Settings")

    # Prefer the key typed into the UI. Otherwise fall back to GROQ_API_KEY,
    # which load_dotenv() may have populated from a local .env file; without
    # this fallback the loaded environment was never consulted.
    api_key = st.text_input("Groq API Key", type="password") or os.getenv("GROQ_API_KEY", "")
    if not api_key:
        st.info("Please enter your Groq API key to continue.")
        st.stop()

    model_name = st.selectbox("Choose Model", options=list(MODEL_OPTIONS.keys()), index=0)

    uploaded_file = st.file_uploader("Upload PDF or DOCX file", type=["pdf", "docx"])
    if uploaded_file is None:
        st.info("Please upload a PDF or DOCX file to proceed.")
        st.stop()

    # Export the current transcript as JSON under the "model" key.
    st.download_button(
        label="Download Chat History",
        data=json.dumps({"model": st.session_state["chat_history"]}, indent=4),
        file_name="history.json",
        mime="application/json",
    )
# --- Groq API Functions ---
async def generate_response_async(prompt, groq_api_key, model):
    """Send a single-turn user prompt to Groq and return the reply text.

    Args:
        prompt: Full prompt text, sent as one ``user`` message.
        groq_api_key: API key used to construct the ``AsyncGroq`` client.
        model: Model id passed to the chat-completions endpoint.

    Returns:
        The first choice's message content with surrounding whitespace
        stripped, or an empty string when the reply carries no content.
    """
    # The client is used as an async context manager so it is closed as soon
    # as the request completes.
    async with AsyncGroq(api_key=groq_api_key) as client:
        chat_completion = await client.chat.completions.create(
            messages=[{"role": "user", "content": prompt}],
            model=model,
            max_tokens=1024,
        )
    # content may be None; guard before calling .strip() so a missing reply
    # does not surface as an AttributeError.
    content = chat_completion.choices[0].message.content
    return (content or "").strip()
async def augmented_generation(question: str, knowledge_cache: str, groq_api_key: str, model: str) -> str:
    """Answer ``question`` with the cached document text placed in the prompt.

    Args:
        question: The user's question.
        knowledge_cache: Document text inserted under the KNOWLEDGE BASE header.
        groq_api_key: API key forwarded to ``generate_response_async``.
        model: Model id forwarded to ``generate_response_async``.

    Returns:
        The text returned by ``generate_response_async`` for the prompt.
    """
    # Layout: knowledge base first, then the question, then an open "ANSWER:"
    # cue for the model to complete.
    prompt = (
        f"KNOWLEDGE BASE:\n{knowledge_cache}\n"
        f"QUESTION: {question}\n"
        "ANSWER:"
    )
    return await generate_response_async(prompt, groq_api_key, model)
# Load the knowledge cache once per uploaded file. The cache is keyed on the
# upload's name and size so that choosing a different document replaces it.
upload_signature = (uploaded_file.name, uploaded_file.size)
if st.session_state.get("knowledge_cache_source") != upload_signature:
    with st.spinner('Loading knowledge base...'):
        # Streamlit keeps uploads in memory, so uploaded_file.name is not a
        # path on disk. Spill the bytes to a temp file (keeping the extension,
        # which load_knowledge_cache uses to pick an extractor) and load that.
        suffix = os.path.splitext(uploaded_file.name)[1]
        tmp_path = None
        try:
            with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp_file:
                tmp_file.write(uploaded_file.getvalue())
                tmp_path = tmp_file.name
            knowledge = load_knowledge_cache(tmp_path)
        except Exception as e:
            st.error(f"An error occurred: {str(e)}")
            st.stop()
        finally:
            # Always remove the temp copy, including when the script stops.
            if tmp_path is not None and os.path.exists(tmp_path):
                os.remove(tmp_path)

    # load_knowledge_cache reports its own error and returns None on failure;
    # stop instead of caching None, which would end up in the prompt as "None".
    if knowledge is None:
        st.stop()
    st.session_state["knowledge_cache"] = knowledge
    st.session_state["knowledge_cache_source"] = upload_signature

# Display chat messages from history
for message in st.session_state["chat_history"]:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])

# React to user input
if prompt := st.chat_input("Ask a question:"):
    # Add user message to chat history
    st.session_state["chat_history"].append({"role": "user", "content": prompt})

    # Display user message in chat message container
    with st.chat_message("user"):
        st.markdown(prompt)

    # Get model response; perf_counter is monotonic, so the reported duration
    # cannot be skewed by wall-clock adjustments.
    start_time = time.perf_counter()
    with st.chat_message("assistant"):
        message_placeholder = st.empty()
        full_response = ""
        try:
            full_response = asyncio.run(
                augmented_generation(
                    prompt,
                    st.session_state["knowledge_cache"],
                    api_key,
                    MODEL_OPTIONS[model_name],
                )
            )
        except Exception as e:
            st.error(f"An error occurred: {str(e)}")
            full_response = "Sorry, an error occurred while processing your request."
        # Show the text with a cursor glyph, then replace it with the final
        # text. The glyph had been corrupted by a mis-decoded character.
        message_placeholder.markdown(full_response + "▌")
        message_placeholder.markdown(full_response)

    response_time = time.perf_counter() - start_time
    st.write(f"Response Time: {response_time:.2f} seconds")

    # Add assistant response to chat history
    st.session_state["chat_history"].append({"role": "assistant", "content": full_response})