forked from explainX/explainx
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathprotodash.py
More file actions
120 lines (84 loc) · 3.12 KB
/
protodash.py
File metadata and controls
120 lines (84 loc) · 3.12 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
from .imports import *
from .encode_decode_cat_col import *
from .rescale_numeric_feature import *
from .PDASH import ProtodashExplainer
from .PDASH_utils import *
"""
This class calculates similar prototypes
Input:
"""
class protodash():
    """Prepare a DataFrame and look up ProtoDash prototypes for one of its rows.

    Typical use is ``preprocess_data(df, y_variable)`` followed by
    ``find_prototypes(row_number)``.

    State kept on the instance:
      encode            -- ``encode_decode_cat_col`` helper, created lazily
      get_col           -- ``get_cols`` helper, created lazily
      cat_col           -- columns reported as categorical by ``get_col``
      actual_variables  -- columns kept after removing ``_impact`` /
                           ``_rescaled`` columns
      classification    -- True when the target has fewer than 10 distinct
                           values
      df                -- the preprocessed DataFrame
      y_variable        -- name of the target column
    """

    def __init__(self):
        super(protodash, self).__init__()
        self.encode = None
        self.get_col = None
        self.cat_col = []
        self.actual_variables = []
        self.classification = False
        self.df = None
        self.y_variable = None

    def encode_categorical_var(self, df):
        """Encode every categorical column of ``df`` and return the result.

        The list of categorical columns is stored in ``self.cat_col`` so that
        ``decode_categorical_var`` can later process the same columns.
        """
        # Identity checks: the helpers are arbitrary objects, so ``==`` could
        # be hijacked by a custom ``__eq__``.
        if self.get_col is None:
            self.get_col = get_cols()
        # Only the second element (categorical columns) is used here.
        _, self.cat_col = self.get_col.get_cate_numer_col(df)
        if self.encode is None:
            self.encode = encode_decode_cat_col()
        for col_name in self.cat_col:
            df = self.encode.encode_col(df, col_name)
        return df

    def decode_categorical_var(self, df):
        """Run ``decode_col`` over every column in ``self.cat_col``; return ``df``."""
        for col_name in self.cat_col:
            df = self.encode.decode_col(df, col_name)
        return df

    def preprocess_data(self, df, y_variable):
        """Store the cleaned data and target name on the instance.

        Steps: encode categorical columns, drop a column literally named
        ``"index"`` if present, remove ``_impact`` / ``_rescaled`` columns,
        then decide whether the target looks like a classification target.

        Returns:
            True, always.
        """
        self.y_variable = y_variable
        df = self.encode_categorical_var(df)
        # A missing "index" column is expected and not an error; any other
        # failure should surface instead of being swallowed.
        df = df.drop(columns="index", errors="ignore")
        # Find and remove variables containing SHAP values.
        self.df = self.remove_impact_columns(df)
        self.classification = self.is_classification()
        return True

    def remove_impact_columns(self, df):
        """Return ``df`` without columns whose name contains ``_impact`` or ``_rescaled``.

        The surviving column names are also stored in ``self.actual_variables``.
        """
        self.actual_variables = [
            col for col in df.columns
            if "_impact" not in col and "_rescaled" not in col
        ]
        return df[self.actual_variables]

    def is_classification(self):
        """Return True when the target column has fewer than 10 distinct values."""
        return bool(self.df[self.y_variable].nunique() < 10)

    def find_prototypes(self, row_number):
        """Return a DataFrame of prototype rows for the row at ``row_number``.

        Args:
            row_number: positional index of the row of ``self.df`` to explain.

        Returns:
            DataFrame with the columns in ``self.actual_variables`` (decoded
            through ``decode_categorical_var``) plus a ``"Weight(%)"`` column
            holding each prototype's share of the total rounded weight.
        """
        data, Z = self.z_train_good(row_number)

        explainer = ProtodashExplainer()
        # The original wrapped this call in a bare try/except whose fallback
        # was the very same call; a single call behaves the same and lets a
        # real failure propagate with its original traceback.
        W, S, _ = explainer.explain(Z, data, m=5, kernelType='other', sigma=5)

        # Rows of ``data`` selected by the explainer become the prototypes.
        dfs = pd.DataFrame.from_records(data[S, 0:].astype('double'))
        dfs.columns = self.actual_variables
        rounded_weights = np.around(W, 5)
        dfs["Weight(%)"] = (rounded_weights / np.sum(rounded_weights)) * 100
        return self.decode_categorical_var(dfs)

    def z_train_good(self, row_number):
        """Split ``self.df`` into the query row and the remaining candidate rows.

        Args:
            row_number: positional index of the query row.

        Returns:
            ``(array, Z)`` where ``array`` holds the values of every other
            row (restricted to rows sharing the query row's target value when
            ``self.classification`` is true) and ``Z`` is the query row as a
            ``(1, n_columns)`` array.
        """
        row = self.df.iloc[row_number]
        # Remove the query row by position. Dropping by index label would
        # also remove any other rows that happen to share that label.
        keep = np.ones(len(self.df), dtype=bool)
        keep[row_number] = False
        candidates = self.df[keep]
        if self.classification:
            # Only compare against rows with the same target value.
            predict_value = row[self.y_variable]
            candidates = candidates[candidates[self.y_variable] == predict_value]
        Z = row.values.reshape(-1, 1).T
        return candidates.values, Z