|
5 | 5 | import pickle |
6 | 6 | import os |
7 | 7 |
|
8 | | -# 🎨 Set Streamlit theme (for more color) |
| 8 | +# 🎨 Set Streamlit theme (for more color and wide layout) |
9 | 9 | st.set_page_config( |
10 | 10 | page_title="Customer Churn Prediction", |
11 | 11 | page_icon="🔮", |
12 | | - layout="centered", |
13 | | - initial_sidebar_state="auto" |
| 12 | + layout="wide", |
| 13 | + initial_sidebar_state="expanded" |
14 | 14 | ) |
15 | | -# Background with gradient using markdown (limited colors in Streamlit natively) |
| 15 | + |
| 16 | +# ----------------- PATHS ----------------- |
| 17 | +BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) |
| 18 | +MODELS_DIR = os.path.join(BASE_DIR, 'models') |
| 19 | +IMAGES_DIR = os.path.join(BASE_DIR, 'images') |
| 20 | + |
| 21 | +# ----------------- CUSTOM CSS ----------------- |
16 | 22 | st.markdown( |
17 | 23 | """ |
18 | 24 | <style> |
| 25 | + /* Main Background */ |
19 | 26 | .stApp { |
20 | | - background: linear-gradient(135deg, #f8ffae 0%, #43cea2 100%); |
| 27 | + background: linear-gradient(120deg, #fdfbfb 0%, #ebedee 100%); |
21 | 28 | } |
22 | | - .highlight { |
23 | | - padding: 0.5em 1em; |
24 | | - border-radius: 0.5em; |
25 | | - margin: 1em 0; |
| 29 | + |
| 30 | + /* Sidebar Styling */ |
| 31 | + section[data-testid="stSidebar"] { |
| 32 | + background-color: #ffffff; |
| 33 | + background-image: linear-gradient(315deg, #ffffff 0%, #d7e1ec 74%); |
26 | 34 | } |
27 | | - .info-hl { background: #d0f9ff; } |
28 | | - .warn-hl { background: #ffcccc; } |
29 | | - .succ-hl { background: #dcffe4; } |
30 | | - .predict-prob { |
31 | | - font-size: 2em; |
32 | | - font-weight: bold; |
33 | | - color: #fa8231; |
| 35 | +
|
| 36 | + /* Card Styling */ |
| 37 | + .css-1r6slb0, .css-12oz5g7 { |
| 38 | + padding: 1rem; |
| 39 | + border-radius: 10px; |
| 40 | + background: white; |
| 41 | + box-shadow: 0 4px 6px rgba(0,0,0,0.1); |
| 42 | + } |
| 43 | +
|
| 44 | + /* Headers */ |
| 45 | + h1, h2, h3 { |
| 46 | + color: #2c3e50; |
| 47 | + font-family: 'Helvetica Neue', sans-serif; |
| 48 | + } |
| 49 | + |
| 50 | + .main-header { |
| 51 | + text-align: center; |
| 52 | + background: -webkit-linear-gradient(#8e44ad, #3498db); |
| 53 | + -webkit-background-clip: text; |
| 54 | + -webkit-text-fill-color: transparent; |
| 55 | + font-size: 3rem; |
| 56 | + font-weight: 800; |
| 57 | + margin-bottom: 0.5rem; |
| 58 | + } |
| 59 | +
|
| 60 | + /* Custom classes */ |
| 61 | + .card { |
| 62 | + background: rgba(255, 255, 255, 0.9); |
| 63 | + padding: 20px; |
| 64 | + border-radius: 15px; |
| 65 | + box-shadow: 0 4px 15px rgba(0,0,0,0.1); |
| 66 | + margin-bottom: 20px; |
| 67 | + border-left: 5px solid #6c5ce7; |
| 68 | + } |
| 69 | + .result-card { |
| 70 | + background: #ffeaa7; |
| 71 | + padding: 20px; |
| 72 | + border-radius: 15px; |
| 73 | + text-align: center; |
| 74 | + border: 2px dashed #fdcb6e; |
| 75 | + } |
| 76 | + .safe { |
| 77 | + background: #d4edda; |
| 78 | + color: #155724; |
| 79 | + border-color: #c3e6cb; |
| 80 | + } |
| 81 | + .danger { |
| 82 | + background: #f8d7da; |
| 83 | + color: #721c24; |
| 84 | + border-color: #f5c6cb; |
34 | 85 | } |
35 | 86 | </style> |
36 | 87 | """, |
37 | 88 | unsafe_allow_html=True |
38 | 89 | ) |
39 | 90 |
|
40 | | -BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) |
41 | | -MODELS_DIR = os.path.join(BASE_DIR, 'models') |
42 | | - |
43 | | -model = tf.keras.models.load_model(os.path.join(MODELS_DIR, "model.h5")) |
44 | | - |
45 | | -# Load the encoders and scaler |
46 | | -with open(os.path.join(MODELS_DIR, 'label_encoder_gender.pkl'), 'rb') as file: |
47 | | - label_encoder_gender = pickle.load(file) |
| 91 | +# ----------------- LOAD RESOURCES ----------------- |
| 92 | +@st.cache_resource |
| 93 | +def load_all_models(): |
| 94 | + try: |
| 95 | + model_loaded = tf.keras.models.load_model(os.path.join(MODELS_DIR, "model.h5")) |
| 96 | + |
| 97 | + with open(os.path.join(MODELS_DIR, 'label_encoder_gender.pkl'), 'rb') as file: |
| 98 | + le_gender = pickle.load(file) |
| 99 | + |
| 100 | + with open(os.path.join(MODELS_DIR, 'onehot_encoder_geo.pkl'), 'rb') as file: |
| 101 | + ohe_geo = pickle.load(file) |
| 102 | + |
| 103 | + with open(os.path.join(MODELS_DIR, 'scaler.pkl'), 'rb') as file: |
| 104 | + s_scaler = pickle.load(file) |
| 105 | + |
| 106 | + return model_loaded, le_gender, ohe_geo, s_scaler |
| 107 | + except Exception as e: |
| 108 | + st.error(f"Error loading models: {e}") |
| 109 | + return None, None, None, None |
48 | 110 |
|
49 | | -with open(os.path.join(MODELS_DIR, 'onehot_encoder_geo.pkl'), 'rb') as file: |
50 | | - onehot_encoder_geo = pickle.load(file) |
| 111 | +model, label_encoder_gender, onehot_encoder_geo, scaler = load_all_models() |
51 | 112 |
|
52 | | -with open(os.path.join(MODELS_DIR, 'scaler.pkl'), 'rb') as file: |
53 | | - scaler = pickle.load(file) |
| 113 | +# ----------------- NAVIGATION ----------------- |
| 114 | +with st.sidebar: |
| 115 | + st.image("https://cdn-icons-png.flaticon.com/512/4144/4144517.png", width=100) |
| 116 | + st.title("Navigation") |
| 117 | + page = st.radio("Go to:", ["🔮 Predict Churn", "🧩 Model Architecture"]) |
| 118 | + |
| 119 | + st.markdown("---") |
| 120 | + st.info("Built with Neural Networks & Streamlit") |
54 | 121 |
|
55 | | -st.markdown( |
56 | | - """ |
57 | | - <div style='text-align:center;'> |
58 | | - <h1 style='color:#3b6978; font-size:2.5em; margin-bottom:0.2em;'>🔮 Customer Churn Prediction 🔮</h1> |
59 | | - <p style='color:#204051; font-size:1.2em;'> |
60 | | - Empower your business with colorful, instant predictions! |
61 | | - </p> |
62 | | - </div> |
63 | | - """, |
64 | | - unsafe_allow_html=True, |
65 | | -) |
| 122 | +# ----------------- PAGE: PREDICTION ----------------- |
| 123 | +if page == "🔮 Predict Churn": |
| 124 | + st.markdown("<h1 class='main-header'>Customer Churn Predictor</h1>", unsafe_allow_html=True) |
| 125 | + st.markdown("<p style='text-align:center; color:#7f8c8d; font-size:1.2em;'>Enter customer details below to predict the likelihood of churning.</p>", unsafe_allow_html=True) |
66 | 126 |
|
67 | | -# Add a colored horizontal rule for aesthetics |
68 | | -st.markdown("<hr style='border-top: 2px dotted #00b894;'>", unsafe_allow_html=True) |
| 127 | + if model is not None: |
| 128 | + st.markdown("<div class='card'>", unsafe_allow_html=True) |
| 129 | + with st.form("churn_prediction_form"): |
| 130 | + st.markdown("### 📝 Customer Profile") |
| 131 | + |
| 132 | + # Layout: 3 columns for better spacing (CRT column fix) |
| 133 | + col1, col2, col3 = st.columns(3) |
| 134 | + |
| 135 | + with col1: |
| 136 | + st.markdown("**Demographics**") |
| 137 | + geography = st.selectbox('🌎 Geography', onehot_encoder_geo.categories_[0]) |
| 138 | + gender = st.selectbox('🚻 Gender', label_encoder_gender.classes_) |
| 139 | + age = st.slider('🎂 Age', 18, 92, 30) |
| 140 | + |
| 141 | + with col2: |
| 142 | + st.markdown("**Financials**") |
| 143 | + credit_score = st.number_input('💳 Credit Score', min_value=0, max_value=1000, value=650) |
| 144 | + balance = st.number_input('🏦 Balance', min_value=0.0, value=10000.0) |
| 145 | + estimated_salary = st.number_input('💰 Salary', min_value=0.0, value=50000.0) |
| 146 | + |
| 147 | + with col3: |
| 148 | + st.markdown("**Account Details**") |
| 149 | + tenure = st.slider('⌛ Tenure (Years)', 0, 10, 3) |
| 150 | + num_of_products = st.slider('🛒 Products', 1, 4, 1) |
| 151 | + has_cr_card = st.radio('💳 Credit Card?', ['Yes', 'No'], horizontal=True) |
| 152 | + is_active_member = st.radio('🟢 Active Member?', ['Yes', 'No'], horizontal=True) |
69 | 153 |
|
70 | | -with st.form("churn_prediction_form"): |
71 | | - st.markdown("<div class='highlight info-hl'>Please fill out the customer details:</div>", unsafe_allow_html=True) |
72 | | - c1, c2 = st.columns(2) |
73 | | - with c1: |
74 | | - geography = st.selectbox('🌎 Geography', onehot_encoder_geo.categories_[0]) |
75 | | - gender = st.selectbox('🚻 Gender', label_encoder_gender.classes_) |
76 | | - age = st.slider('🎂 Age', 18, 92, 30) |
77 | | - credit_score = st.number_input('💳 Credit Score', min_value=0, max_value=1000, value=650) |
78 | | - tenure = st.slider('⌛ Tenure (years)', 0, 10, 3) |
79 | | - with c2: |
80 | | - balance = st.number_input('🏦 Balance', min_value=0.0, value=10000.0) |
81 | | - estimated_salary = st.number_input('💰 Estimated Salary', min_value=0.0, value=50000.0) |
82 | | - num_of_products = st.slider('🛒 Number of Products', 1, 4, 1) |
83 | | - has_cr_card = st.selectbox('💳 Has Credit Card', ['No', 'Yes']) |
84 | | - is_active_member = st.selectbox('🟢 Is Active Member', ['No', 'Yes']) |
| 154 | + # Map Yes/No |
| 155 | + has_cr_card_val = 1 if has_cr_card == 'Yes' else 0 |
| 156 | + is_active_member_val = 1 if is_active_member == 'Yes' else 0 |
| 157 | + |
| 158 | + st.markdown("---") |
| 159 | + submitted = st.form_submit_button("🚀 Run Prediction", use_container_width=True, type="primary") |
| 160 | + st.markdown("</div>", unsafe_allow_html=True) |
85 | 161 |
|
86 | | - # Map Yes/No to 1/0 |
87 | | - has_cr_card_val = 1 if has_cr_card == 'Yes' else 0 |
88 | | - is_active_member_val = 1 if is_active_member == 'Yes' else 0 |
89 | | - |
90 | | - # Rainbow button! |
91 | | - submitted = st.form_submit_button( |
92 | | - "🌈 Predict Churn 🌈", |
93 | | - use_container_width=True |
94 | | - ) |
| 162 | + if submitted: |
| 163 | + # Prepare Input |
| 164 | + input_data = pd.DataFrame({ |
| 165 | + 'CreditScore': [credit_score], |
| 166 | + 'Gender': [label_encoder_gender.transform([gender])[0]], |
| 167 | + 'Age': [age], |
| 168 | + 'Tenure': [tenure], |
| 169 | + 'Balance': [balance], |
| 170 | + 'NumOfProducts': [num_of_products], |
| 171 | + 'HasCrCard': [has_cr_card_val], |
| 172 | + 'IsActiveMember': [is_active_member_val], |
| 173 | + 'EstimatedSalary': [estimated_salary] |
| 174 | + }) |
| 175 | + |
| 176 | + # Encode Geo |
| 177 | + geo_encoded = onehot_encoder_geo.transform([[geography]]).toarray() |
| 178 | + geo_encoded_df = pd.DataFrame(geo_encoded, columns=onehot_encoder_geo.get_feature_names_out(['Geography'])) |
| 179 | + |
| 180 | + # Combine |
| 181 | + input_data = pd.concat([input_data.reset_index(drop=True), geo_encoded_df], axis=1) |
| 182 | + |
| 183 | + # Scale |
| 184 | + input_data_scaled = scaler.transform(input_data) |
| 185 | + |
| 186 | + # Predict |
| 187 | + with st.spinner('Thinking...'): |
| 188 | + prediction = model.predict(input_data_scaled, verbose=0) |
| 189 | + prediction_proba = prediction[0][0] |
| 190 | + |
| 191 | + # Display Results |
| 192 | + st.markdown("### 📊 Prediction Results") |
| 193 | + |
| 194 | + col_res1, col_res2 = st.columns([1, 2]) |
| 195 | + |
| 196 | + with col_res1: |
| 197 | + st.metric(label="Churn Probability", value=f"{prediction_proba:.2%}") |
| 198 | + |
| 199 | + with col_res2: |
| 200 | + if prediction_proba > 0.5: |
| 201 | + st.markdown( |
| 202 | + f""" |
| 203 | + <div class='result-card danger'> |
| 204 | + <h2>⚠️ High Risk of Churn</h2> |
| 205 | + <p>This customer has a <b>{prediction_proba:.2%}</b> probability of leaving.</p> |
| 206 | + </div> |
| 207 | + """, |
| 208 | + unsafe_allow_html=True |
| 209 | + ) |
| 210 | + else: |
| 211 | + st.markdown( |
| 212 | + f""" |
| 213 | + <div class='result-card safe'> |
| 214 | + <h2>✅ Low Risk of Churn</h2> |
| 215 | + <p>This customer is likely to stay (Probability: <b>{prediction_proba:.2%}</b>).</p> |
| 216 | + </div> |
| 217 | + """, |
| 218 | + unsafe_allow_html=True |
| 219 | + ) |
95 | 220 |
|
96 | | -if submitted: |
97 | | - # Prepare input data |
98 | | - input_data = pd.DataFrame({ |
99 | | - 'CreditScore': [credit_score], |
100 | | - 'Gender': [label_encoder_gender.transform([gender])[0]], |
101 | | - 'Age': [age], |
102 | | - 'Tenure': [tenure], |
103 | | - 'Balance': [balance], |
104 | | - 'NumOfProducts': [num_of_products], |
105 | | - 'HasCrCard': [has_cr_card_val], |
106 | | - 'IsActiveMember': [is_active_member_val], |
107 | | - 'EstimatedSalary': [estimated_salary] |
108 | | - }) |
109 | | - |
110 | | - # One-hot encode 'Geography' |
111 | | - geo_encoded = onehot_encoder_geo.transform([[geography]]).toarray() |
112 | | - geo_encoded_df = pd.DataFrame(geo_encoded, columns=onehot_encoder_geo.get_feature_names_out(['Geography'])) |
| 221 | +# ----------------- PAGE: DIAGRAM ----------------- |
| 222 | +elif page == "🧩 Model Architecture": |
| 223 | + st.markdown("<h1 class='main-header'>Model Architecture</h1>", unsafe_allow_html=True) |
| 224 | + st.markdown("<p style='text-align:center;'>Visual representation of the Neural Network and Data Pipeline.</p>", unsafe_allow_html=True) |
113 | 225 |
|
114 | | - # Combine one-hot encoded columns with input data |
115 | | - input_data = pd.concat([input_data.reset_index(drop=True), geo_encoded_df], axis=1) |
| 226 | + tab1, tab2 = st.tabs(["Simple View", "Detailed View"]) |
116 | 227 |
|
117 | | - # Scale the input data |
118 | | - input_data_scaled = scaler.transform(input_data) |
119 | | - |
120 | | - # Predict churn |
121 | | - with st.spinner('✨ Calculating your colorful prediction...'): |
122 | | - prediction = model.predict(input_data_scaled, verbose=0) |
123 | | - prediction_proba = prediction[0][0] |
124 | | - |
125 | | - # Show result with custom coloring and emoji |
126 | | - st.markdown( |
127 | | - "<div class='highlight succ-hl'>🔥 <b>Prediction Complete!</b> 🔥</div>", |
128 | | - unsafe_allow_html=True |
129 | | - ) |
130 | | - st.markdown( |
131 | | - f"<div class='predict-prob'>Churn Probability: <span style='color:#0984e3'>{prediction_proba:.2%}</span></div>", |
132 | | - unsafe_allow_html=True |
133 | | - ) |
134 | | - |
135 | | - if prediction_proba > 0.5: |
136 | | - st.markdown( |
137 | | - "<div class='highlight warn-hl'>" |
138 | | - "⚠️ <span style='color:#d63031; font-weight:bold;'>The customer is <u>likely</u> to churn.</span> " |
139 | | - "Take action now! 🚨" |
140 | | - "</div>", |
141 | | - unsafe_allow_html=True |
142 | | - ) |
143 | | - else: |
144 | | - st.markdown( |
145 | | - "<div class='highlight succ-hl'>" |
146 | | - "✅ <span style='color:#00b894; font-weight:bold;'>The customer is <u>not likely</u> to churn.</span> " |
147 | | - "Keep engaging! 🎉" |
148 | | - "</div>", |
149 | | - unsafe_allow_html=True |
150 | | - ) |
151 | | -# End with a nice colorful footer |
| 228 | + with tab1: |
| 229 | + img_path = os.path.join(IMAGES_DIR, "pipeline_diagram.png") |
| 230 | + if os.path.exists(img_path): |
| 231 | + st.image(img_path, caption="High-Level Pipeline", use_column_width=True) |
| 232 | + else: |
| 233 | + st.warning("Simple diagram not found.") |
| 234 | + |
| 235 | + with tab2: |
| 236 | + img_path_det = os.path.join(IMAGES_DIR, "pipeline_diagram_detailed.png") |
| 237 | + if os.path.exists(img_path_det): |
| 238 | + st.image(img_path_det, caption="Detailed Architecture", use_column_width=True) |
| 239 | + else: |
| 240 | + st.warning("Detailed diagram not found.") |
| 241 | + |
| 242 | +# Footer |
152 | 243 | st.markdown( |
153 | 244 | """ |
154 | | - <hr style='border-top:2px solid #fdcb6e;'> |
155 | | - <div style="text-align:center;font-size:1em;color:#636e72;"> |
156 | | - Made with <span style="color:#fd79a8;">❤</span> and <span style="color:#00b894;">Streamlit</span> |
| 245 | + <div style="text-align:center; margin-top: 50px; color: #b2bec3;"> |
| 246 | + <hr> |
| 247 | + <small>Customer Churn Prediction App | v2.0 | Powered by TensorFlow</small> |
157 | 248 | </div> |
158 | 249 | """, |
159 | 250 | unsafe_allow_html=True |
|
0 commit comments