Skip to content

Commit 35a38fe

Browse files
updated design
1 parent a4c8f03 commit 35a38fe

1 file changed

Lines changed: 213 additions & 122 deletions

File tree

src/app.py

Lines changed: 213 additions & 122 deletions
Original file line numberDiff line numberDiff line change
@@ -5,155 +5,246 @@
55
import pickle
66
import os
77

8-
# 🎨 Set Streamlit theme (for more color)
8+
# 🎨 Set Streamlit theme (for more color and wide layout)
99
st.set_page_config(
1010
page_title="Customer Churn Prediction",
1111
page_icon="🔮",
12-
layout="centered",
13-
initial_sidebar_state="auto"
12+
layout="wide",
13+
initial_sidebar_state="expanded"
1414
)
15-
# Background with gradient using markdown (limited colors in Streamlit natively)
15+
16+
# ----------------- PATHS -----------------
17+
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
18+
MODELS_DIR = os.path.join(BASE_DIR, 'models')
19+
IMAGES_DIR = os.path.join(BASE_DIR, 'images')
20+
21+
# ----------------- CUSTOM CSS -----------------
1622
st.markdown(
1723
"""
1824
<style>
25+
/* Main Background */
1926
.stApp {
20-
background: linear-gradient(135deg, #f8ffae 0%, #43cea2 100%);
27+
background: linear-gradient(120deg, #fdfbfb 0%, #ebedee 100%);
2128
}
22-
.highlight {
23-
padding: 0.5em 1em;
24-
border-radius: 0.5em;
25-
margin: 1em 0;
29+
30+
/* Sidebar Styling */
31+
section[data-testid="stSidebar"] {
32+
background-color: #ffffff;
33+
background-image: linear-gradient(315deg, #ffffff 0%, #d7e1ec 74%);
2634
}
27-
.info-hl { background: #d0f9ff; }
28-
.warn-hl { background: #ffcccc; }
29-
.succ-hl { background: #dcffe4; }
30-
.predict-prob {
31-
font-size: 2em;
32-
font-weight: bold;
33-
color: #fa8231;
35+
36+
/* Card Styling */
37+
.css-1r6slb0, .css-12oz5g7 {
38+
padding: 1rem;
39+
border-radius: 10px;
40+
background: white;
41+
box-shadow: 0 4px 6px rgba(0,0,0,0.1);
42+
}
43+
44+
/* Headers */
45+
h1, h2, h3 {
46+
color: #2c3e50;
47+
font-family: 'Helvetica Neue', sans-serif;
48+
}
49+
50+
.main-header {
51+
text-align: center;
52+
background: -webkit-linear-gradient(#8e44ad, #3498db);
53+
-webkit-background-clip: text;
54+
-webkit-text-fill-color: transparent;
55+
font-size: 3rem;
56+
font-weight: 800;
57+
margin-bottom: 0.5rem;
58+
}
59+
60+
/* Custom classes */
61+
.card {
62+
background: rgba(255, 255, 255, 0.9);
63+
padding: 20px;
64+
border-radius: 15px;
65+
box-shadow: 0 4px 15px rgba(0,0,0,0.1);
66+
margin-bottom: 20px;
67+
border-left: 5px solid #6c5ce7;
68+
}
69+
.result-card {
70+
background: #ffeaa7;
71+
padding: 20px;
72+
border-radius: 15px;
73+
text-align: center;
74+
border: 2px dashed #fdcb6e;
75+
}
76+
.safe {
77+
background: #d4edda;
78+
color: #155724;
79+
border-color: #c3e6cb;
80+
}
81+
.danger {
82+
background: #f8d7da;
83+
color: #721c24;
84+
border-color: #f5c6cb;
3485
}
3586
</style>
3687
""",
3788
unsafe_allow_html=True
3889
)
3990

40-
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
41-
MODELS_DIR = os.path.join(BASE_DIR, 'models')
42-
43-
model = tf.keras.models.load_model(os.path.join(MODELS_DIR, "model.h5"))
44-
45-
# Load the encoders and scaler
46-
with open(os.path.join(MODELS_DIR, 'label_encoder_gender.pkl'), 'rb') as file:
47-
label_encoder_gender = pickle.load(file)
91+
# ----------------- LOAD RESOURCES -----------------
92+
@st.cache_resource
93+
def load_all_models():
94+
try:
95+
model_loaded = tf.keras.models.load_model(os.path.join(MODELS_DIR, "model.h5"))
96+
97+
with open(os.path.join(MODELS_DIR, 'label_encoder_gender.pkl'), 'rb') as file:
98+
le_gender = pickle.load(file)
99+
100+
with open(os.path.join(MODELS_DIR, 'onehot_encoder_geo.pkl'), 'rb') as file:
101+
ohe_geo = pickle.load(file)
102+
103+
with open(os.path.join(MODELS_DIR, 'scaler.pkl'), 'rb') as file:
104+
s_scaler = pickle.load(file)
105+
106+
return model_loaded, le_gender, ohe_geo, s_scaler
107+
except Exception as e:
108+
st.error(f"Error loading models: {e}")
109+
return None, None, None, None
48110

49-
with open(os.path.join(MODELS_DIR, 'onehot_encoder_geo.pkl'), 'rb') as file:
50-
onehot_encoder_geo = pickle.load(file)
111+
model, label_encoder_gender, onehot_encoder_geo, scaler = load_all_models()
51112

52-
with open(os.path.join(MODELS_DIR, 'scaler.pkl'), 'rb') as file:
53-
scaler = pickle.load(file)
113+
# ----------------- NAVIGATION -----------------
114+
with st.sidebar:
115+
st.image("https://cdn-icons-png.flaticon.com/512/4144/4144517.png", width=100)
116+
st.title("Navigation")
117+
page = st.radio("Go to:", ["🔮 Predict Churn", "🧩 Model Architecture"])
118+
119+
st.markdown("---")
120+
st.info("Built with Neural Networks & Streamlit")
54121

55-
st.markdown(
56-
"""
57-
<div style='text-align:center;'>
58-
<h1 style='color:#3b6978; font-size:2.5em; margin-bottom:0.2em;'>🔮 Customer Churn Prediction 🔮</h1>
59-
<p style='color:#204051; font-size:1.2em;'>
60-
Empower your business with colorful, instant predictions!
61-
</p>
62-
</div>
63-
""",
64-
unsafe_allow_html=True,
65-
)
122+
# ----------------- PAGE: PREDICTION -----------------
123+
if page == "🔮 Predict Churn":
124+
st.markdown("<h1 class='main-header'>Customer Churn Predictor</h1>", unsafe_allow_html=True)
125+
st.markdown("<p style='text-align:center; color:#7f8c8d; font-size:1.2em;'>Enter customer details below to predict the likelihood of churning.</p>", unsafe_allow_html=True)
66126

67-
# Add a colored horizontal rule for aesthetics
68-
st.markdown("<hr style='border-top: 2px dotted #00b894;'>", unsafe_allow_html=True)
127+
if model is not None:
128+
st.markdown("<div class='card'>", unsafe_allow_html=True)
129+
with st.form("churn_prediction_form"):
130+
st.markdown("### 📝 Customer Profile")
131+
132+
# Layout: 3 columns for better spacing (CRT column fix)
133+
col1, col2, col3 = st.columns(3)
134+
135+
with col1:
136+
st.markdown("**Demographics**")
137+
geography = st.selectbox('🌎 Geography', onehot_encoder_geo.categories_[0])
138+
gender = st.selectbox('🚻 Gender', label_encoder_gender.classes_)
139+
age = st.slider('🎂 Age', 18, 92, 30)
140+
141+
with col2:
142+
st.markdown("**Financials**")
143+
credit_score = st.number_input('💳 Credit Score', min_value=0, max_value=1000, value=650)
144+
balance = st.number_input('🏦 Balance', min_value=0.0, value=10000.0)
145+
estimated_salary = st.number_input('💰 Salary', min_value=0.0, value=50000.0)
146+
147+
with col3:
148+
st.markdown("**Account Details**")
149+
tenure = st.slider('⌛ Tenure (Years)', 0, 10, 3)
150+
num_of_products = st.slider('🛒 Products', 1, 4, 1)
151+
has_cr_card = st.radio('💳 Credit Card?', ['Yes', 'No'], horizontal=True)
152+
is_active_member = st.radio('🟢 Active Member?', ['Yes', 'No'], horizontal=True)
69153

70-
with st.form("churn_prediction_form"):
71-
st.markdown("<div class='highlight info-hl'>Please fill out the customer details:</div>", unsafe_allow_html=True)
72-
c1, c2 = st.columns(2)
73-
with c1:
74-
geography = st.selectbox('🌎 Geography', onehot_encoder_geo.categories_[0])
75-
gender = st.selectbox('🚻 Gender', label_encoder_gender.classes_)
76-
age = st.slider('🎂 Age', 18, 92, 30)
77-
credit_score = st.number_input('💳 Credit Score', min_value=0, max_value=1000, value=650)
78-
tenure = st.slider('⌛ Tenure (years)', 0, 10, 3)
79-
with c2:
80-
balance = st.number_input('🏦 Balance', min_value=0.0, value=10000.0)
81-
estimated_salary = st.number_input('💰 Estimated Salary', min_value=0.0, value=50000.0)
82-
num_of_products = st.slider('🛒 Number of Products', 1, 4, 1)
83-
has_cr_card = st.selectbox('💳 Has Credit Card', ['No', 'Yes'])
84-
is_active_member = st.selectbox('🟢 Is Active Member', ['No', 'Yes'])
154+
# Map Yes/No
155+
has_cr_card_val = 1 if has_cr_card == 'Yes' else 0
156+
is_active_member_val = 1 if is_active_member == 'Yes' else 0
157+
158+
st.markdown("---")
159+
submitted = st.form_submit_button("🚀 Run Prediction", use_container_width=True, type="primary")
160+
st.markdown("</div>", unsafe_allow_html=True)
85161

86-
# Map Yes/No to 1/0
87-
has_cr_card_val = 1 if has_cr_card == 'Yes' else 0
88-
is_active_member_val = 1 if is_active_member == 'Yes' else 0
89-
90-
# Rainbow button!
91-
submitted = st.form_submit_button(
92-
"🌈 Predict Churn 🌈",
93-
use_container_width=True
94-
)
162+
if submitted:
163+
# Prepare Input
164+
input_data = pd.DataFrame({
165+
'CreditScore': [credit_score],
166+
'Gender': [label_encoder_gender.transform([gender])[0]],
167+
'Age': [age],
168+
'Tenure': [tenure],
169+
'Balance': [balance],
170+
'NumOfProducts': [num_of_products],
171+
'HasCrCard': [has_cr_card_val],
172+
'IsActiveMember': [is_active_member_val],
173+
'EstimatedSalary': [estimated_salary]
174+
})
175+
176+
# Encode Geo
177+
geo_encoded = onehot_encoder_geo.transform([[geography]]).toarray()
178+
geo_encoded_df = pd.DataFrame(geo_encoded, columns=onehot_encoder_geo.get_feature_names_out(['Geography']))
179+
180+
# Combine
181+
input_data = pd.concat([input_data.reset_index(drop=True), geo_encoded_df], axis=1)
182+
183+
# Scale
184+
input_data_scaled = scaler.transform(input_data)
185+
186+
# Predict
187+
with st.spinner('Thinking...'):
188+
prediction = model.predict(input_data_scaled, verbose=0)
189+
prediction_proba = prediction[0][0]
190+
191+
# Display Results
192+
st.markdown("### 📊 Prediction Results")
193+
194+
col_res1, col_res2 = st.columns([1, 2])
195+
196+
with col_res1:
197+
st.metric(label="Churn Probability", value=f"{prediction_proba:.2%}")
198+
199+
with col_res2:
200+
if prediction_proba > 0.5:
201+
st.markdown(
202+
f"""
203+
<div class='result-card danger'>
204+
<h2>⚠️ High Risk of Churn</h2>
205+
<p>This customer has a <b>{prediction_proba:.2%}</b> probability of leaving.</p>
206+
</div>
207+
""",
208+
unsafe_allow_html=True
209+
)
210+
else:
211+
st.markdown(
212+
f"""
213+
<div class='result-card safe'>
214+
<h2>✅ Low Risk of Churn</h2>
215+
<p>This customer is likely to stay (Probability: <b>{prediction_proba:.2%}</b>).</p>
216+
</div>
217+
""",
218+
unsafe_allow_html=True
219+
)
95220

96-
if submitted:
97-
# Prepare input data
98-
input_data = pd.DataFrame({
99-
'CreditScore': [credit_score],
100-
'Gender': [label_encoder_gender.transform([gender])[0]],
101-
'Age': [age],
102-
'Tenure': [tenure],
103-
'Balance': [balance],
104-
'NumOfProducts': [num_of_products],
105-
'HasCrCard': [has_cr_card_val],
106-
'IsActiveMember': [is_active_member_val],
107-
'EstimatedSalary': [estimated_salary]
108-
})
109-
110-
# One-hot encode 'Geography'
111-
geo_encoded = onehot_encoder_geo.transform([[geography]]).toarray()
112-
geo_encoded_df = pd.DataFrame(geo_encoded, columns=onehot_encoder_geo.get_feature_names_out(['Geography']))
221+
# ----------------- PAGE: DIAGRAM -----------------
222+
elif page == "🧩 Model Architecture":
223+
st.markdown("<h1 class='main-header'>Model Architecture</h1>", unsafe_allow_html=True)
224+
st.markdown("<p style='text-align:center;'>Visual representation of the Neural Network and Data Pipeline.</p>", unsafe_allow_html=True)
113225

114-
# Combine one-hot encoded columns with input data
115-
input_data = pd.concat([input_data.reset_index(drop=True), geo_encoded_df], axis=1)
226+
tab1, tab2 = st.tabs(["Simple View", "Detailed View"])
116227

117-
# Scale the input data
118-
input_data_scaled = scaler.transform(input_data)
119-
120-
# Predict churn
121-
with st.spinner('✨ Calculating your colorful prediction...'):
122-
prediction = model.predict(input_data_scaled, verbose=0)
123-
prediction_proba = prediction[0][0]
124-
125-
# Show result with custom coloring and emoji
126-
st.markdown(
127-
"<div class='highlight succ-hl'>🔥 <b>Prediction Complete!</b> 🔥</div>",
128-
unsafe_allow_html=True
129-
)
130-
st.markdown(
131-
f"<div class='predict-prob'>Churn Probability: <span style='color:#0984e3'>{prediction_proba:.2%}</span></div>",
132-
unsafe_allow_html=True
133-
)
134-
135-
if prediction_proba > 0.5:
136-
st.markdown(
137-
"<div class='highlight warn-hl'>"
138-
"⚠️ <span style='color:#d63031; font-weight:bold;'>The customer is <u>likely</u> to churn.</span> "
139-
"Take action now! 🚨"
140-
"</div>",
141-
unsafe_allow_html=True
142-
)
143-
else:
144-
st.markdown(
145-
"<div class='highlight succ-hl'>"
146-
"✅ <span style='color:#00b894; font-weight:bold;'>The customer is <u>not likely</u> to churn.</span> "
147-
"Keep engaging! 🎉"
148-
"</div>",
149-
unsafe_allow_html=True
150-
)
151-
# End with a nice colorful footer
228+
with tab1:
229+
img_path = os.path.join(IMAGES_DIR, "pipeline_diagram.png")
230+
if os.path.exists(img_path):
231+
st.image(img_path, caption="High-Level Pipeline", use_column_width=True)
232+
else:
233+
st.warning("Simple diagram not found.")
234+
235+
with tab2:
236+
img_path_det = os.path.join(IMAGES_DIR, "pipeline_diagram_detailed.png")
237+
if os.path.exists(img_path_det):
238+
st.image(img_path_det, caption="Detailed Architecture", use_column_width=True)
239+
else:
240+
st.warning("Detailed diagram not found.")
241+
242+
# Footer
152243
st.markdown(
153244
"""
154-
<hr style='border-top:2px solid #fdcb6e;'>
155-
<div style="text-align:center;font-size:1em;color:#636e72;">
156-
Made with <span style="color:#fd79a8;">&#10084;</span> and <span style="color:#00b894;">Streamlit</span>
245+
<div style="text-align:center; margin-top: 50px; color: #b2bec3;">
246+
<hr>
247+
<small>Customer Churn Prediction App | v2.0 | Powered by TensorFlow</small>
157248
</div>
158249
""",
159250
unsafe_allow_html=True

0 commit comments

Comments
 (0)