11from .ccontree import Config , solve , Tree
22from sklearn .base import BaseEstimator
3- from sklearn .utils .validation import check_array , check_is_fitted
3+ from sklearn .utils .validation import check_array , check_is_fitted , validate_data
44from sklearn .utils ._param_validation import Interval , StrOptions
55from sklearn .metrics import accuracy_score
66from pycontree .export import TreeExporter
@@ -56,7 +56,7 @@ def _process_fit_data(self, X, y):
5656 """
5757 with warnings .catch_warnings ():
5858 warnings .filterwarnings (action = "ignore" , category = FutureWarning )
59- X = self . _validate_data ( X , ensure_min_samples = 2 , dtype = np .float64 )
59+ X = validate_data ( self , X , ensure_min_samples = 2 , dtype = np .float64 )
6060
6161 y = check_array (y , ensure_2d = False , dtype = np .intc )
6262 self .n_classes_ = len (np .unique (y ))
@@ -70,7 +70,7 @@ def _process_score_data(self, X, y):
7070 """
7171 with warnings .catch_warnings ():
7272 warnings .filterwarnings (action = "ignore" , category = FutureWarning )
73- X = self . _validate_data ( X , reset = False , dtype = np .float64 )
73+ X = validate_data ( self , X , reset = False , dtype = np .float64 )
7474
7575
7676 y = check_array (y , ensure_2d = False , dtype = np .intc )
@@ -84,7 +84,7 @@ def _process_predict_data(self, X):
8484 """
8585 with warnings .catch_warnings ():
8686 warnings .filterwarnings (action = "ignore" , category = FutureWarning )
87- return self . _validate_data ( X , reset = False , dtype = np .float64 )
87+ return validate_data ( self , X , reset = False , dtype = np .float64 )
8888
8989 def fit (self , X , y ) -> Self :
9090 """
@@ -136,6 +136,48 @@ def predict(self, X):
136136
137137 return self .tree_ .predict (X )
138138
def predict_proba(self, X):
    """
    Predict class probabilities for the given input feature data.

    The probability for a sample is the class frequency distribution of the
    training samples that end up in the same leaf of the fitted tree as the
    sample (computed by routing the stored training data down the tree
    alongside `X`).

    Args:
        X : array-like, shape = (n_samples, n_features)
            Data matrix

    Returns:
        numpy.ndarray: A 2D array of shape (n_samples, n_classes) that
        represents the predicted class probabilities of the test data.
        The (i, j)-th element is the predicted probability of the j-th
        class for the i-th instance in `X`.
    """
    check_is_fitted(self, "tree_")
    X = self._process_predict_data(X)
    probabilities = np.zeros((len(X), self.n_classes_))
    train_data = (self.train_X_, self.train_y_)
    # Route every sample through the tree in one recursive pass; each leaf
    # fills in the rows of `probabilities` for the samples that reach it.
    self._recursive_predict_proba(self.tree_, probabilities, np.arange(len(X)), X, train_data)
    # Sanity check: every row must sum to ~1 (tolerance for floating-point rounding).
    assert probabilities.sum(axis=1).min() >= 1 - 1e-4
    return probabilities
159+
160+ def _recursive_predict_proba (self , tree , probabilities , indices , X , train_data ):
161+ train_X = train_data [0 ]
162+ train_y = train_data [1 ]
163+ if tree .is_leaf_node ():
164+ n = len (train_y )
165+ assert (n > 0 )
166+ all_counts = np .zeros (self .n_classes_ )
167+ unique , counts = np .unique (train_y , return_counts = True )
168+ for label , count in zip (unique , counts ):
169+ all_counts [label ] = count
170+ probs = all_counts / n
171+ probabilities [indices ] = probs
172+ else :
173+ indices_left = np .intersect1d (np .argwhere ( (X [:, tree .get_split_feature ()] <= tree .get_split_threshold ())), indices )
174+ indices_right = np .intersect1d (np .argwhere (~ (X [:, tree .get_split_feature ()] <= tree .get_split_threshold ())), indices )
175+ sel = train_X [:, tree .get_split_feature ()] <= tree .get_split_threshold ()
176+ train_data_left = (train_X [ sel , :], train_y [ sel ])
177+ train_data_right = (train_X [~ sel , :], train_y [~ sel ])
178+ self ._recursive_predict_proba (tree .get_left (), probabilities , indices_left , X , train_data_left )
179+ self ._recursive_predict_proba (tree .get_right (), probabilities , indices_right , X , train_data_right )
180+
139181 def score (self , X , y_true ) -> float :
140182 """
141183 Computes the score for the given input feature data
0 commit comments