1414logger = logging .getLogger (__name__ )
1515
1616
17- def make_prefix_groups (strings : List [str ]) -> Dict [str , List [str ]]:
17+ def make_prefix_groups (strings : Iterable [str ]) -> Dict [str , List [str ]]:
1818 groups : Dict [str , List [str ]] = {}
1919 for s in strings :
2020 if m := re .match (r"(.*?_+)([^_]+)$" , s ):
@@ -50,7 +50,7 @@ def get_parameters(self, key, base_dir="."):
5050 def get_hash_value (self ):
5151 raise NotImplementedError (f"Implement { self .__class__ .__name__ } .get_hash_value()" )
5252
53- def set_column_choices (self , choices : List [str ]):
53+ def set_column_choices (self , choices : Mapping [str , bool ]):
5454 pass
5555
5656
@@ -546,13 +546,18 @@ class ColumnChoiceParam(ChoiceParam):
546546 """A ChoiceParam which DaskTransformPlugin knows
547547 it should automatically update with a list of columns"""
548548
549- def set_column_choices (self , choices : List [str ]):
550- self .set_choices (list (choices ))
549+ def set_column_choices (self , choices : Mapping [str , bool ]):
550+ self .set_choices (list (choices . keys () ))
551551
552552 def get_column (self , df ):
553553 return _dataframe_get_column (df , self .value )
554554
555555
556+ class NumericColumnChoiceParam (ColumnChoiceParam ):
557+ def set_column_choices (self , choices : Mapping [str , bool ]):
558+ self .set_choices ([c for c , n in choices .items () if n ])
559+
560+
556561class ColumnOrNoneChoiceParam (ColumnChoiceParam ):
557562 DEFAULT_VALUE = "— NONE —"
558563
@@ -573,8 +578,8 @@ def get_column(self, df):
573578
574579
575580class ColumnGroupChoiceParam (ChoiceParam ):
576- def set_column_choices (self , choices : List [str ]):
577- self .set_choices ([n + "*" for n in make_prefix_groups (choices ).keys ()])
581+ def set_column_choices (self , choices : Mapping [str , bool ]):
582+ self .set_choices ([n + "*" for n in make_prefix_groups (choices . keys () ).keys ()])
578583
579584 def get_column_prefix (self ):
580585 return self .value .removesuffix ("*" )
@@ -634,7 +639,7 @@ class ColumnOrStringParam(ColumnChoiceParam):
634639 DEFAULT_VALUE : Any = ""
635640 PREFIX = "— "
636641
637- def set_column_choices (self , choices : List [str ]):
642+ def set_column_choices (self , choices : Mapping [str , bool ]):
638643 self .set_choices ([self .PREFIX + c for c in choices ])
639644
640645 def get_column_name (self ) -> Optional [str ]:
@@ -678,6 +683,9 @@ def set_choices(self, choices: Iterable[str]):
678683class ColumnOrIntegerParam (ColumnOrStringParam ):
679684 DEFAULT_VALUE : int = 0
680685
686+ def set_column_choices (self , choices : Mapping [str , bool ]):
687+ self .set_choices ([self .PREFIX + c for c , n in choices .items () if n ])
688+
681689 def __init__ (
682690 self ,
683691 label : str ,
@@ -755,7 +763,7 @@ def set_parameter(self, key: str, value: Union[bool, int, float, str], base_dir:
755763 elif isinstance (param , ScalarParam ):
756764 param .set_value (value )
757765
758- def set_column_choices (self , choices : List [str ]):
766+ def set_column_choices (self , choices : Mapping [str , bool ]):
759767 logger .debug ("HasSubParametersMixin.set_column_choices %s" , choices )
760768 for p in self .params .values ():
761769 p .set_column_choices (choices )
@@ -856,7 +864,7 @@ def get_hash_value(self):
856864 digest .update (p .get_hash_value ().encode ("utf-8" ))
857865 return digest .hexdigest ()
858866
859- def set_column_choices (self , choices : List [str ]):
867+ def set_column_choices (self , choices : Mapping [str , bool ]):
860868 self .param .set_column_choices (choices )
861869 for p in self .params :
862870 p .set_column_choices (choices )
@@ -897,10 +905,10 @@ def get_parameters(self, key, base_dir="."):
897905 yield f"{ key } .{ n } ._label" , p .label
898906 yield from p .get_parameters (f"{ key } .{ n } " , base_dir )
899907
900- def set_column_choices (self , choices : List [str ]):
908+ def set_column_choices (self , choices : Mapping [str , bool ]):
901909 params_by_label = {p .label : p for p in self .params }
902910 self .params = []
903- for label in choices :
911+ for label in choices . keys () :
904912 if label in params_by_label :
905913 self .params .append (params_by_label [label ])
906914 else :
@@ -913,6 +921,11 @@ def get_column_params(self):
913921 yield p .label , p
914922
915923
924+ class PerNumericColumnArrayParam (PerColumnArrayParam ):
925+ def set_column_choices (self , choices : Mapping [str , bool ]):
926+ super ().set_column_choices ({k : True for k , v in choices .items () if v })
927+
928+
916929class FileArrayParam (ArrayParam ):
917930 """FileArrayParam is an ArrayParam arranged per-file. Using this class
918931 really just marks it as expecting to be populated from an open file
0 commit comments