Skip to content

Commit 2044833

Browse files
committed
fix: adjust XGBoost max_bin and improve BayesSearch param validation
Add _adjust_xgboost_parameters to dynamically enforce max_bin >= 2 for XGBoost, preventing errors on small datasets/folds. Refactor Bayesian search parameter validation to use _is_simple_categorical helper, correctly identifying hashable list parameters for auto-wrapping in Categorical. Change logging level from warning to info for auto-corrections in parameter spaces.
1 parent 293dd26 commit 2044833

1 file changed

Lines changed: 64 additions & 14 deletions

File tree

ml_grid/pipeline/grid_search_cross_validate.py

Lines changed: 64 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -359,17 +359,29 @@ def __init__(
359359
# Ensure list-based parameters are wrapped in Categorical for Bayesian search
360360
if self.global_params.bayessearch:
361361
self.logger.debug("Validating parameter space for Bayesian search...")
362+
363+
def _is_simple_categorical(val):
    """
    Determines if a value is a list of simple, hashable choices
    suitable for wrapping in skopt.space.Categorical.

    Returns True only for a ``list`` or ``np.ndarray`` holding more than
    one item, all of which are hashable. Returns False for every other
    input, including tuples, strings, empty or single-item sequences,
    nested lists, multi-dimensional arrays and 0-d arrays.
    """
    if not isinstance(val, (list, np.ndarray)):
        return False
    try:
        # A list is only considered a categorical choice if it has more than one item.
        # Single-item lists are treated as fixed parameters by BayesSearchCV.
        # len() lives inside the try because a 0-d ndarray has no length and
        # raises TypeError; such a value is simply "not a list of choices".
        if len(val) <= 1:
            return False
        for item in val:
            # Unhashable items (nested lists, array rows) raise TypeError.
            hash(item)
    except TypeError:
        return False
    return True
378+
362379
if isinstance(parameter_space, list):
363380
for i, space in enumerate(parameter_space):
364381
new_space = {}
365382
for key, value in space.items():
366-
is_list_of_choices = (
367-
isinstance(value, (list, np.ndarray))
368-
and value
369-
and not isinstance(value[0], list)
370-
)
371-
if is_list_of_choices and not is_skopt_space(value):
372-
self.logger.warning(
383+
if _is_simple_categorical(value) and not is_skopt_space(value):
384+
self.logger.info(
373385
f"Auto-correcting param '{key}' for BayesSearch: wrapping list in Categorical."
374386
)
375387
new_space[key] = Categorical(value)
@@ -379,13 +391,8 @@ def __init__(
379391
elif isinstance(parameter_space, dict):
380392
new_parameter_space = {}
381393
for key, value in parameter_space.items():
382-
is_list_of_choices = (
383-
isinstance(value, (list, np.ndarray))
384-
and value
385-
and not isinstance(value[0], list)
386-
)
387-
if is_list_of_choices and not is_skopt_space(value):
388-
self.logger.warning(
394+
if _is_simple_categorical(value) and not is_skopt_space(value):
395+
self.logger.info(
389396
f"Auto-correcting param '{key}' for BayesSearch: wrapping list in Categorical."
390397
)
391398
new_parameter_space[key] = Categorical(value)
@@ -483,6 +490,13 @@ def __init__(
483490
"Adjusted CatBoost subsample parameter space to prevent errors on small CV folds."
484491
)
485492

493+
# Adjust XGBoost parameters for small datasets
494+
if "xgb" in method_name.lower():
495+
self._adjust_xgboost_parameters(parameter_space)
496+
self.logger.debug(
497+
"Adjusted XGBoost max_bin parameter space to prevent errors on small CV folds."
498+
)
499+
486500
# Force sequential search for H2O/GPU models
487501
original_grid_n_jobs = self.global_parameters.grid_n_jobs
488502
if is_gpu_model or is_h2o_model:
@@ -1111,6 +1125,42 @@ def adjust_param(param_value):
11111125
elif isinstance(parameter_space, dict) and "rsm" in parameter_space:
11121126
parameter_space["rsm"] = adjust_param(parameter_space["rsm"])
11131127

1128+
def _adjust_xgboost_parameters(
    self, parameter_space: Union[Dict, List[Dict]]
) -> None:
    """
    Dynamically adjusts 'max_bin' for XGBoost to prevent errors on small datasets.

    Every 'max_bin' entry found in ``parameter_space`` (a single dict, or a
    list of dicts) is rewritten in place so that no candidate value is
    below 2. Dicts without a 'max_bin' key are left untouched.

    Handled value shapes:
      * list / np.ndarray: values below 2 are dropped; if nothing is left
        the result is ``[2]``. The result is always a plain list.
      * int / float scalar: raised to 2 when smaller.
      * skopt dimension with ``low``/``high``: bounds are raised in place.
      * skopt dimension with ``categories``: rebuilt without the values
        below 2 (only when at least one value has to be removed).
      * anything else: returned unchanged.
    """
    # Ensure max_bin is at least 2
    min_max_bin = 2

    def keep_valid(values):
        # Drop candidates below the minimum; never return an empty space.
        kept = [v for v in values if v >= min_max_bin]
        return kept or [min_max_bin]

    def adjust_param(param_value):
        # Plain containers and scalars are checked first: they need no
        # skopt inspection at all.
        if isinstance(param_value, (list, np.ndarray)):
            return keep_valid(param_value)
        if isinstance(param_value, (int, float)):
            return min_max_bin if param_value < min_max_bin else param_value
        if not is_skopt_space(param_value):
            return param_value

        if hasattr(param_value, "low"):
            # Numeric skopt range: raise the bounds on the existing object.
            new_low = max(param_value.low, min_max_bin)
            if new_low > param_value.high:
                # If low > high, we must adjust high as well
                param_value.high = max(param_value.high, min_max_bin)
                new_low = min(new_low, param_value.high)
            param_value.low = new_low
            return param_value

        # Categorical-style skopt dimension. Lists that were auto-wrapped
        # for Bayesian search arrive here, so they must be filtered too.
        categories = getattr(param_value, "categories", None)
        if categories is not None:
            kept = keep_valid(categories)
            if len(kept) != len(categories):
                # NOTE(review): rebuilt with default prior/transform; a
                # custom prior cannot be carried over once categories are
                # removed - confirm this is acceptable for callers.
                return type(param_value)(kept)
        return param_value

    if isinstance(parameter_space, list):
        for params in parameter_space:
            if "max_bin" in params:
                params["max_bin"] = adjust_param(params["max_bin"])
    elif isinstance(parameter_space, dict) and "max_bin" in parameter_space:
        parameter_space["max_bin"] = adjust_param(parameter_space["max_bin"])
1163+
11141164
def _shutdown_h2o_if_needed(self, algorithm: Any):
11151165
"""Safely shuts down the H2O cluster if the algorithm is an H2O model."""
11161166
# Use the module-level tuple

0 commit comments

Comments
 (0)