Skip to content

Commit 90dca14

Browse files
committed
Obtain expected category levels from fitted model
1 parent 0315a60 commit 90dca14

1 file changed

Lines changed: 21 additions & 1 deletion

File tree

pySEQTarget/helpers/_predict_model.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,27 @@
33

44
def _predict_model(self, model, newdata):
55
newdata = newdata.to_pandas()
6+
7+
# Original behavior - convert fixed_cols to category
68
for col in self.fixed_cols:
79
if col in newdata.columns:
810
newdata[col] = newdata[col].astype("category")
9-
return np.array(model.predict(newdata))
11+
12+
try:
13+
return np.array(model.predict(newdata))
14+
except Exception as e:
15+
if "mismatching levels" in str(e):
16+
# Fix category levels from model's design_info
17+
if hasattr(model, 'model') and hasattr(model.model, 'data') and hasattr(model.model.data, 'design_info'):
18+
design_info = model.model.data.design_info
19+
for factor, factor_info in design_info.factor_infos.items():
20+
if factor_info.type == 'categorical':
21+
col_name = factor.name()
22+
if col_name in newdata.columns:
23+
expected_categories = list(factor_info.categories)
24+
newdata[col_name] = newdata[col_name].astype(str)
25+
newdata[col_name] = newdata[col_name].astype('category')
26+
newdata[col_name] = newdata[col_name].cat.set_categories(expected_categories)
27+
return np.array(model.predict(newdata))
28+
else:
29+
raise

0 commit comments

Comments
 (0)