Skip to content

Commit 55c4349

Browse files
committed
add predict_proba, update to python 3.10, sklearn 1.6
1 parent 11dbb6f commit 55c4349

6 files changed

Lines changed: 54 additions & 9 deletions

File tree

.github/workflows/pip.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ jobs:
1919
fail-fast: false
2020
matrix:
2121
platform: [windows-latest, macos-latest, ubuntu-latest]
22-
python-version: ["3.8", "3.11"]
22+
python-version: ["3.10", "3.11", "3.12"]
2323

2424
runs-on: ${{ matrix.platform }}
2525

.github/workflows/pypi_release.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ jobs:
4343

4444
- uses: pypa/cibuildwheel@v2.16.5
4545
env:
46-
CIBW_SKIP: "pp*"
46+
CIBW_SKIP: "pp* cp36-* cp37-* cp38-* cp39-*"
4747
CIBW_ARCHS_MACOS: auto universal2
4848
CIBW_PRERELEASE_PYTHONS: true
4949

@@ -67,7 +67,7 @@ jobs:
6767
- uses: actions/setup-python@v5
6868
name: Set up Python 3.x
6969
with:
70-
python-version: "3.8"
70+
python-version: "3.10"
7171

7272
- uses: actions/download-artifact@v4
7373
name: Download wheels

pyproject.toml

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@ build-backend = "setuptools.build_meta"
77

88
[project]
99
name = "pycontree"
10-
version = "1.0.4"
10+
version = "1.0.5"
11+
requires-python = ">=3.10"
1112
description = "Python Wrapper ConTree: Optimal Decision Trees for Continuous Feature Data"
1213
license= {file = "LICENSE"}
1314
readme = "README.md"
@@ -23,7 +24,7 @@ maintainers = [
2324
dependencies =[
2425
"pandas>=1.0.0",
2526
"numpy>=1.18.0",
26-
"scikit-learn >=1.2.0",
27+
"scikit-learn >=1.6.0",
2728
"typing_extensions>=4.0.0"
2829

2930
]
@@ -33,6 +34,8 @@ classifiers = [
3334
"Operating System :: OS Independent", ]
3435

3536
[tool.setuptools.packages.find]
37+
where = ["python"]
38+
include = ["pycontree*"]
3639
exclude = ["datasets*", "train-datasets*", "examples*"]
3740

3841
[project.urls]
Lines changed: 46 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from .ccontree import Config, solve, Tree
22
from sklearn.base import BaseEstimator
3-
from sklearn.utils.validation import check_array, check_is_fitted
3+
from sklearn.utils.validation import check_array, check_is_fitted, validate_data
44
from sklearn.utils._param_validation import Interval, StrOptions
55
from sklearn.metrics import accuracy_score
66
from pycontree.export import TreeExporter
@@ -56,7 +56,7 @@ def _process_fit_data(self, X, y):
5656
"""
5757
with warnings.catch_warnings():
5858
warnings.filterwarnings(action="ignore", category=FutureWarning)
59-
X = self._validate_data(X, ensure_min_samples=2, dtype=np.float64)
59+
X = validate_data(self, X, ensure_min_samples=2, dtype=np.float64)
6060

6161
y = check_array(y, ensure_2d=False, dtype=np.intc)
6262
self.n_classes_ = len(np.unique(y))
@@ -70,7 +70,7 @@ def _process_score_data(self, X, y):
7070
"""
7171
with warnings.catch_warnings():
7272
warnings.filterwarnings(action="ignore", category=FutureWarning)
73-
X = self._validate_data(X, reset=False, dtype=np.float64)
73+
X = validate_data(self, X, reset=False, dtype=np.float64)
7474

7575

7676
y = check_array(y, ensure_2d=False, dtype=np.intc)
@@ -84,7 +84,7 @@ def _process_predict_data(self, X):
8484
"""
8585
with warnings.catch_warnings():
8686
warnings.filterwarnings(action="ignore", category=FutureWarning)
87-
return self._validate_data(X, reset=False, dtype=np.float64)
87+
return validate_data(self, X, reset=False, dtype=np.float64)
8888

8989
def fit(self, X, y) -> Self:
9090
"""
@@ -136,6 +136,48 @@ def predict(self, X):
136136

137137
return self.tree_.predict(X)
138138

139+
def predict_proba(self, X):
    """
    Predict class probabilities for the given input feature data.

    Args:
        X : array-like, shape = (n_samples, n_features)
            Data matrix

    Returns:
        numpy.ndarray: A 2D array of shape (n_samples, n_classes) where the
        (i, j)-th element is the predicted probability that the i-th instance
        in `X` belongs to the j-th class.
    """
    check_is_fitted(self, "tree_")
    X = self._process_predict_data(X)
    probabilities = np.zeros((len(X), self.n_classes_))
    train_data = (self.train_X_, self.train_y_)
    # Walk the fitted tree once, routing every row of X to its leaf and filling
    # in that leaf's training-class frequencies as the predicted probabilities.
    self._recursive_predict_proba(self.tree_, probabilities, np.arange(len(X)), X, train_data)
    # Sanity check: every row must sum to probability 1 (allow floating-point error).
    assert probabilities.sum(axis=1).min() >= 1 - 1e-4
    return probabilities
159+
160+
def _recursive_predict_proba(self, tree, probabilities, indices, X, train_data):
    """
    Recursively fill `probabilities` for the rows of `X` listed in `indices`.

    At a leaf, the predicted probabilities are the class frequencies of the
    training samples that reached that leaf. `train_data` is a
    (train_X, train_y) tuple holding the training samples routed to the
    current node; it is partitioned alongside the prediction rows at each
    internal split.
    """
    train_X, train_y = train_data
    if tree.is_leaf_node():
        n = len(train_y)
        assert n > 0  # every leaf must have received at least one training sample
        # Class-frequency counts over the training samples in this leaf;
        # minlength guarantees one slot per class even for absent classes.
        counts = np.bincount(train_y, minlength=self.n_classes_)
        probabilities[indices] = counts / n
    else:
        # Hoist the split parameters instead of recomputing them per mask.
        feature = tree.get_split_feature()
        threshold = tree.get_split_threshold()
        # Partition the prediction rows still assigned to this node directly
        # with a boolean mask (equivalent to, but cheaper than, argwhere +
        # intersect1d, since `indices` is always sorted and unique).
        goes_left = X[indices, feature] <= threshold
        # Partition the training samples with the same split rule.
        sel = train_X[:, feature] <= threshold
        self._recursive_predict_proba(tree.get_left(), probabilities, indices[goes_left], X,
                                      (train_X[sel, :], train_y[sel]))
        self._recursive_predict_proba(tree.get_right(), probabilities, indices[~goes_left], X,
                                      (train_X[~sel, :], train_y[~sel]))
180+
139181
def score(self, X, y_true) -> float:
140182
"""
141183
Computes the score for the given input feature data

0 commit comments

Comments
 (0)