This project implements a Decision Tree Classifier on the famous Iris dataset to classify iris plants into three species: Setosa, Versicolor, and Virginica. The project includes pre-pruning (limiting the tree's maximum depth), hyperparameter tuning with GridSearchCV, and visualization of the trained tree.
The Iris dataset is a classic dataset in machine learning and pattern recognition. It contains 150 instances and 4 numeric features:
- Sepal length (cm)
- Sepal width (cm)
- Petal length (cm)
- Petal width (cm)
The dataset has 3 classes, each with 50 samples.
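You can confirm these figures by loading the data through scikit-learn and inspecting it. The sketch below is minimal and assumes scikit-learn 0.23 or later, which provides the as_frame option of load_iris. The variable names are illustrative.

```python
# Sketch: inspect the Iris data via scikit-learn's bundled loader.
# Assumes scikit-learn >= 0.23 (for as_frame=True) and pandas installed.
from sklearn.datasets import load_iris

iris = load_iris(as_frame=True)
X = iris.data    # pandas DataFrame, one column per feature
y = iris.target  # pandas Series of integer labels 0, 1, 2

print(X.shape)  # (150, 4)
print(list(X.columns))

# Map each species name to its sample count (labels 0, 1, 2 follow target_names order)
class_counts = {
    name: int(count)
    for name, count in zip(iris.target_names, y.value_counts().sort_index())
}
print(class_counts)
```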
Project features:
- Decision Tree Classifier implementation from scikit-learn
- Train/test split to evaluate model performance
- Decision tree visualization for easy interpretation
- Hyperparameter tuning with GridSearchCV
- Evaluation metrics: accuracy, precision, recall, F1-score, and confusion matrix
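Precision, recall, and F1-score are all derived from the confusion matrix. The illustrative NumPy sketch below is not part of the project code. It recomputes the three metrics per class from the 3×3 matrix reported in the results, following scikit-learn's convention that rows are true classes and columns are predicted classes. These are the same per-class values that classification_report prints.

```python
import numpy as np

# Confusion matrix as reported in this README's results.
# Rows = true class, columns = predicted class (scikit-learn's convention).
cm = np.array([
    [10, 0, 0],   # true setosa
    [0, 11, 2],   # true versicolor
    [0, 0, 7],    # true virginica
])

tp = np.diag(cm).astype(float)        # correctly classified samples per class
precision = tp / cm.sum(axis=0)       # correct / all samples predicted as that class
recall = tp / cm.sum(axis=1)          # correct / all samples truly in that class
f1 = 2 * precision * recall / (precision + recall)  # harmonic mean of the two
accuracy = tp.sum() / cm.sum()        # overall fraction correct

for name, p, r, f in zip(["setosa", "versicolor", "virginica"], precision, recall, f1):
    print(f"{name:<10} precision={p:.2f} recall={r:.2f} f1={f:.2f}")
print(f"accuracy={accuracy:.3f}")
```

The only off-diagonal entry, 2 Versicolor samples predicted as Virginica, lowers Versicolor's recall and Virginica's precision at the same time.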
Installation:
- Clone the repository:
  ```bash
  git clone <repository_url>
  ```
- Install dependencies:
  ```bash
  pip install -r requirements.txt
  ```
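Usage:

```python
# Import libraries
import pandas as pd
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.tree import DecisionTreeClassifier, plot_tree
from sklearn.metrics import confusion_matrix, classification_report, accuracy_score
import matplotlib.pyplot as plt

# Load dataset into a DataFrame so feature names travel with the data
iris = load_iris()
X = pd.DataFrame(iris['data'], columns=iris['feature_names'])
y = iris['target']

# Split dataset: 80% train, 20% test (30 samples), fixed seed for reproducibility
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=10)

# Train Decision Tree Classifier, pre-pruned to a maximum depth of 2
model = DecisionTreeClassifier(max_depth=2)
model.fit(X_train, y_train)

# Visualize Decision Tree, labelling each node with feature and class names
plt.figure(figsize=(15, 10))
plot_tree(model, feature_names=iris['feature_names'],
          class_names=list(iris['target_names']), filled=True)
plt.show()

# Evaluate model on the held-out test set
y_pred = model.predict(X_test)
print(confusion_matrix(y_test, y_pred))
print(classification_report(y_test, y_pred))
print("Accuracy:", accuracy_score(y_test, y_pred))
```

Hyperparameter tuning with GridSearchCV (run after the code above, reusing X_train and y_train):

```python
# 'log_loss' requires scikit-learn >= 1.1.
# max_features='auto' was removed in scikit-learn 1.3 (for classifiers it was an
# alias of 'sqrt'); None considers all features at every split.
param_grid = {
    'criterion': ['gini', 'entropy', 'log_loss'],
    'splitter': ['best', 'random'],
    'max_depth': [1, 2, 3, 4, 5],
    'max_features': [None, 'sqrt', 'log2']
}

# 5-fold cross-validation on the training set only
grid = GridSearchCV(DecisionTreeClassifier(), param_grid, cv=5, scoring='accuracy')
grid.fit(X_train, y_train)

print(grid.best_params_)
print("Best cross-validation score:", grid.best_score_)
```

Results (max_depth=2 tree on the 30-sample test split):
- Confusion Matrix: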
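  ```
  [[10  0  0]
   [ 0 11  2]
   [ 0  0  7]]
  ```
- Classification Report:
  - Accuracy: 93% (28 of 30 test samples classified correctly)
  - Setosa is classified perfectly. The only errors are 2 Versicolor samples predicted as Virginica, which gives Versicolor a recall of about 0.85 and Virginica a precision of about 0.78. All other per-class precision and recall values are 1.00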
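These results come from the fixed max_depth=2 tree. The tuning step reports only a cross-validation score. The minimal sketch below also scores the tuned tree on the same held-out split. It makes three assumptions: scikit-learn 1.1 or later (for 'log_loss'), None in place of the removed 'auto' option, and random_state=0 on the estimator, which the original code does not set.

```python
from sklearn.datasets import load_iris
from sklearn.metrics import accuracy_score, classification_report
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
# Same split as the project code: 20% test, random_state=10
X_train, X_test, y_train, y_test = train_test_split(
    iris.data, iris.target, test_size=0.2, random_state=10
)

param_grid = {
    "criterion": ["gini", "entropy", "log_loss"],
    "splitter": ["best", "random"],
    "max_depth": [1, 2, 3, 4, 5],
    "max_features": [None, "sqrt", "log2"],
}
# random_state=0 is an added assumption so repeated runs select the same tree
grid = GridSearchCV(DecisionTreeClassifier(random_state=0), param_grid,
                    cv=5, scoring="accuracy")
grid.fit(X_train, y_train)

# With refit=True (the default), best_estimator_ is retrained on all of X_train
best_tree = grid.best_estimator_
y_pred = best_tree.predict(X_test)
test_accuracy = accuracy_score(y_test, y_pred)

print("Best params:", grid.best_params_)
print(f"CV accuracy:   {grid.best_score_:.3f}")
print(f"Test accuracy: {test_accuracy:.3f}")
print(classification_report(y_test, y_pred, target_names=iris.target_names))
```

Comparing the cross-validation accuracy with the test accuracy gives a rough check on whether the tuned configuration generalizes. With only 30 test samples, a single misclassification moves test accuracy by about 3.3 percentage points.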
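Key takeaways:
- Decision trees do not require feature scaling. Each split compares a single feature against a threshold, so an order-preserving rescaling such as standardization moves the threshold but leaves the partition of samples unchanged.
- Pre-pruning (for example, limiting max_depth) and hyperparameter tuning help prevent overfitting. They stop the tree from growing deep enough to memorize the training set.
- Visualization helps interpret the model's decision-making process.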
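You can check the scaling claim directly. The sketch below fits the same max_depth=2 tree twice, once on raw features and once on standardized features, then compares the feature chosen at each node and the training-set predictions. Two choices are illustrative assumptions not found in the original code: StandardScaler as the rescaling, and random_state=0 for both trees. The comparison uses training points because none of them can sit exactly on a split threshold.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=10)

# Assumption: StandardScaler stands in for "feature scaling"
scaler = StandardScaler().fit(X_train)
X_train_scaled = scaler.transform(X_train)

# Assumption: random_state=0 on both trees so feature tie-breaking is identical
raw_tree = DecisionTreeClassifier(max_depth=2, random_state=0).fit(X_train, y_train)
scaled_tree = DecisionTreeClassifier(max_depth=2, random_state=0).fit(X_train_scaled, y_train)

# tree_.feature lists the feature index split on at each node (-2 marks a leaf)
same_structure = np.array_equal(raw_tree.tree_.feature, scaled_tree.tree_.feature)
same_predictions = np.array_equal(raw_tree.predict(X_train),
                                  scaled_tree.predict(X_train_scaled))
print(same_structure, same_predictions)
```

Only the threshold values differ between the two trees. Each scaled threshold corresponds to the raw threshold passed through the same standardization.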