Decision Tree Classifier on Iris Dataset

Project Overview

This project implements a Decision Tree Classifier on the famous Iris dataset to classify iris plants into three species: Setosa, Versicolor, and Virginica. The project includes pre-pruning, hyperparameter tuning using GridSearchCV, and visualization of the decision tree.

Dataset

The Iris dataset is a classic dataset in machine learning and pattern recognition. It contains 150 instances and 4 numeric features:

  • Sepal length (cm)
  • Sepal width (cm)
  • Petal length (cm)
  • Petal width (cm)

The dataset has 3 classes, each with 50 samples.
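These properties are easy to verify directly. A minimal sketch using scikit-learn's bundled copy of the dataset:

```python
# Confirm the dataset dimensions described above: 150 instances,
# 4 numeric features, and 3 classes with 50 samples each.
from collections import Counter
from sklearn.datasets import load_iris

iris = load_iris()
print(iris['data'].shape)       # (150, 4)
print(iris['feature_names'])    # the four measurements in cm
print(Counter(iris['target']))  # 50 samples per class
```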

Features

  • Decision Tree Classifier implementation from sklearn
  • Train/Test split to evaluate model performance
  • Decision tree visualization for easy interpretation
  • Hyperparameter tuning with GridSearchCV
  • Evaluation metrics: Accuracy, Precision, Recall, F1-score, Confusion Matrix

Installation

  1. Clone the repository:
    git clone <repository_url>
  2. Install dependencies:
    pip install -r requirements.txt

Usage

# Import libraries
import pandas as pd
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.tree import DecisionTreeClassifier, plot_tree
from sklearn.metrics import confusion_matrix, classification_report, accuracy_score
import matplotlib.pyplot as plt

# Load dataset
iris = load_iris()
X = pd.DataFrame(iris['data'], columns=iris['feature_names'])
y = iris['target']

# Split dataset
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=10)

# Train Decision Tree Classifier
model = DecisionTreeClassifier(max_depth=2)
model.fit(X_train, y_train)

# Visualize Decision Tree (feature and class names make the plot readable)
plt.figure(figsize=(15, 10))
plot_tree(model, filled=True, feature_names=iris['feature_names'],
          class_names=iris['target_names'])
plt.show()

# Evaluate model
y_pred = model.predict(X_test)
print(confusion_matrix(y_test, y_pred))
print(classification_report(y_test, y_pred))
print("Accuracy:", accuracy_score(y_test, y_pred))
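Once trained, the model can classify new measurements. A self-contained sketch (the sample values below are made-up for illustration, not taken from the dataset):

```python
# Train the same depth-2 tree as above, then classify one new flower.
import pandas as pd
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
X = pd.DataFrame(iris['data'], columns=iris['feature_names'])
y = iris['target']
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=10)

model = DecisionTreeClassifier(max_depth=2)
model.fit(X_train, y_train)

# One hypothetical flower: sepal length/width, petal length/width (cm).
# Short petals put it firmly on the Setosa side of the first split.
sample = pd.DataFrame([[5.1, 3.5, 1.4, 0.2]], columns=iris['feature_names'])
pred = model.predict(sample)[0]
print(iris['target_names'][pred])
```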

Hyperparameter Tuning

param_grid = {
    'criterion': ['gini', 'entropy', 'log_loss'],
    'splitter': ['best', 'random'],
    'max_depth': [1, 2, 3, 4, 5],
    'max_features': ['sqrt', 'log2', None]  # 'auto' was removed in recent scikit-learn
}

grid = GridSearchCV(DecisionTreeClassifier(), param_grid, cv=5, scoring='accuracy')
grid.fit(X_train, y_train)

print(grid.best_params_)
print("Best cross-validation score:", grid.best_score_)

Results

  • Confusion Matrix:
[[10  0  0]
 [ 0 11  2]
 [ 0  0  7]]
  • Classification Report:
    • Accuracy: 93%
    • High precision and recall for most classes

Key Learnings

  • Decision Trees do not require feature scaling.
  • Pre-pruning and hyperparameter tuning can prevent overfitting.
  • Visualization helps in interpreting the decision-making process of the model.
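The pre-pruning point above can be illustrated by comparing an unrestricted tree with a depth-limited one: the fully grown tree fits the training data perfectly (memorization), while the pruned tree trades a little training accuracy for a much simpler model. A minimal sketch:

```python
# Compare a fully grown tree with a depth-limited one to show
# how pre-pruning caps model capacity.
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
X_train, X_test, y_train, y_test = train_test_split(
    iris['data'], iris['target'], test_size=0.2, random_state=10)

full = DecisionTreeClassifier(random_state=0).fit(X_train, y_train)
pruned = DecisionTreeClassifier(max_depth=2, random_state=0).fit(X_train, y_train)

print("Unpruned train accuracy:", full.score(X_train, y_train))  # 1.0: memorized
print("Pruned train accuracy:  ", pruned.score(X_train, y_train))
print("Unpruned depth:", full.get_depth(), "vs pruned depth:", pruned.get_depth())
```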
