-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.py
More file actions
50 lines (38 loc) · 1.3 KB
/
train.py
File metadata and controls
50 lines (38 loc) · 1.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import tensorflow as tf
from model import build_model
from custom_data_loader import CustomDataLoader
from utils import plot_training_history
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau
def main():
    """
    Run the training loop for the ThoraxScanAI model.

    Loads the datasets from ``ThoraxScanData`` through ``CustomDataLoader``,
    builds a 3-class model with ``build_model``, and trains it for up to 50
    epochs. Training uses checkpointing, early stopping and learning-rate
    reduction on plateau. Finally it plots the training history.

    Parameters:
    - None

    Returns:
    - None
    """
    # Setup paths
    dataset_path = 'ThoraxScanData'

    # Initialize the custom data loader and get the TensorFlow datasets.
    data_loader = CustomDataLoader(dataset_path=dataset_path)
    # The test split is deliberately unused during training; it is held out
    # so the final model can be evaluated separately on unseen data.
    train_ds, val_ds, _test_ds = data_loader.get_data_loaders()

    num_classes = 3  # Normal, Lung_Opacity, Viral Pneumonia
    model = build_model(num_classes)

    initial_lr = 1e-4
    # `learning_rate` replaces the deprecated `lr` alias, which Keras 3
    # rejects and TF >= 2.11 only accepts with a deprecation warning.
    # NOTE(review): categorical_crossentropy expects one-hot labels; confirm
    # CustomDataLoader yields them (integer labels would need
    # sparse_categorical_crossentropy).
    model.compile(
        optimizer=Adam(learning_rate=initial_lr),
        loss='categorical_crossentropy',
        metrics=['accuracy'],
    )

    # Callbacks (all monitor 'val_loss', the Keras default).
    callbacks = [
        # Write a checkpoint only when validation loss improves.
        ModelCheckpoint('best_model.h5', save_best_only=True),
        # Stop after 20 epochs without improvement.
        EarlyStopping(patience=20),
        # min_lr must be below the initial learning rate. The previous value
        # (1e-3 > 1e-4) meant the learning rate was never reduced.
        ReduceLROnPlateau(factor=0.2, patience=5, min_lr=1e-6),
    ]

    # Train the model
    history = model.fit(
        train_ds,
        epochs=50,
        validation_data=val_ds,
        callbacks=callbacks,
    )

    # Plot the training history
    plot_training_history(history)
if __name__ == '__main__':
    # Start training only when run as a script, not when this module is imported.
    main()