Skip to content

Commit 7ba3c9b

Browse files
committed
add in flag for numeric columns to set_column_choices
1 parent 7d19c1d commit 7ba3c9b

7 files changed

Lines changed: 50 additions & 49 deletions

File tree

countess/core/parameters.py

Lines changed: 24 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
logger = logging.getLogger(__name__)
1515

1616

17-
def make_prefix_groups(strings: List[str]) -> Dict[str, List[str]]:
17+
def make_prefix_groups(strings: Iterable[str]) -> Dict[str, List[str]]:
1818
groups: Dict[str, List[str]] = {}
1919
for s in strings:
2020
if m := re.match(r"(.*?_+)([^_]+)$", s):
@@ -50,7 +50,7 @@ def get_parameters(self, key, base_dir="."):
5050
def get_hash_value(self):
5151
raise NotImplementedError(f"Implement {self.__class__.__name__}.get_hash_value()")
5252

53-
def set_column_choices(self, choices: List[str]):
53+
def set_column_choices(self, choices: Mapping[str, bool]):
5454
pass
5555

5656

@@ -546,13 +546,18 @@ class ColumnChoiceParam(ChoiceParam):
546546
"""A ChoiceParam which DaskTransformPlugin knows
547547
it should automatically update with a list of columns"""
548548

549-
def set_column_choices(self, choices: List[str]):
550-
self.set_choices(list(choices))
549+
def set_column_choices(self, choices: Mapping[str, bool]):
550+
self.set_choices(list(choices.keys()))
551551

552552
def get_column(self, df):
553553
return _dataframe_get_column(df, self.value)
554554

555555

556+
class NumericColumnChoiceParam(ColumnChoiceParam):
557+
def set_column_choices(self, choices: Mapping[str, bool]):
558+
self.set_choices([c for c, n in choices.items() if n])
559+
560+
556561
class ColumnOrNoneChoiceParam(ColumnChoiceParam):
557562
DEFAULT_VALUE = "— NONE —"
558563

@@ -573,8 +578,8 @@ def get_column(self, df):
573578

574579

575580
class ColumnGroupChoiceParam(ChoiceParam):
576-
def set_column_choices(self, choices: List[str]):
577-
self.set_choices([n + "*" for n in make_prefix_groups(choices).keys()])
581+
def set_column_choices(self, choices: Mapping[str, bool]):
582+
self.set_choices([n + "*" for n in make_prefix_groups(choices.keys()).keys()])
578583

579584
def get_column_prefix(self):
580585
return self.value.removesuffix("*")
@@ -634,7 +639,7 @@ class ColumnOrStringParam(ColumnChoiceParam):
634639
DEFAULT_VALUE: Any = ""
635640
PREFIX = "— "
636641

637-
def set_column_choices(self, choices: List[str]):
642+
def set_column_choices(self, choices: Mapping[str, bool]):
638643
self.set_choices([self.PREFIX + c for c in choices])
639644

640645
def get_column_name(self) -> Optional[str]:
@@ -678,6 +683,9 @@ def set_choices(self, choices: Iterable[str]):
678683
class ColumnOrIntegerParam(ColumnOrStringParam):
679684
DEFAULT_VALUE: int = 0
680685

686+
def set_column_choices(self, choices: Mapping[str, bool]):
687+
self.set_choices([self.PREFIX + c for c, n in choices.items() if n])
688+
681689
def __init__(
682690
self,
683691
label: str,
@@ -755,7 +763,7 @@ def set_parameter(self, key: str, value: Union[bool, int, float, str], base_dir:
755763
elif isinstance(param, ScalarParam):
756764
param.set_value(value)
757765

758-
def set_column_choices(self, choices: List[str]):
766+
def set_column_choices(self, choices: Mapping[str, bool]):
759767
logger.debug("HasSubParametersMixin.set_column_choices %s", choices)
760768
for p in self.params.values():
761769
p.set_column_choices(choices)
@@ -856,7 +864,7 @@ def get_hash_value(self):
856864
digest.update(p.get_hash_value().encode("utf-8"))
857865
return digest.hexdigest()
858866

859-
def set_column_choices(self, choices: List[str]):
867+
def set_column_choices(self, choices: Mapping[str, bool]):
860868
self.param.set_column_choices(choices)
861869
for p in self.params:
862870
p.set_column_choices(choices)
@@ -897,10 +905,10 @@ def get_parameters(self, key, base_dir="."):
897905
yield f"{key}.{n}._label", p.label
898906
yield from p.get_parameters(f"{key}.{n}", base_dir)
899907

900-
def set_column_choices(self, choices: List[str]):
908+
def set_column_choices(self, choices: Mapping[str, bool]):
901909
params_by_label = {p.label: p for p in self.params}
902910
self.params = []
903-
for label in choices:
911+
for label in choices.keys():
904912
if label in params_by_label:
905913
self.params.append(params_by_label[label])
906914
else:
@@ -913,6 +921,11 @@ def get_column_params(self):
913921
yield p.label, p
914922

915923

924+
class PerNumericColumnArrayParam(PerColumnArrayParam):
925+
def set_column_choices(self, choices: Mapping[str, bool]):
926+
super().set_column_choices({k: True for k, v in choices.items() if v})
927+
928+
916929
class FileArrayParam(ArrayParam):
917930
"""FileArrayParam is an ArrayParam arranged per-file. Using this class
918931
really just marks it as expecting to be populated from an open file

countess/core/plugins.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
HasSubParametersMixin,
3737
MultiParam,
3838
)
39-
from countess.utils.duckdb import duckdb_combine, duckdb_escape_literal
39+
from countess.utils.duckdb import duckdb_combine, duckdb_dtype_is_numeric, duckdb_escape_literal
4040
from countess.utils.files import clean_filename
4141
from countess.utils.pyarrow import python_type_to_arrow_dtype
4242

@@ -137,7 +137,10 @@ def prepare_multi(self, ddbc: DuckDBPyConnection, sources: Mapping[str, DuckDBPy
137137
self.prepare(ddbc, duckdb_combine(ddbc, list(sources.values())))
138138

139139
def prepare(self, ddbc: DuckDBPyConnection, source: Optional[DuckDBPyRelation]) -> None:
140-
self.set_column_choices([] if source is None else source.columns)
140+
if source is None:
141+
self.set_column_choices({})
142+
else:
143+
self.set_column_choices({c: duckdb_dtype_is_numeric(d) for c, d in zip(source.columns, source.dtypes)})
141144

142145
def execute_multi(
143146
self, ddbc: DuckDBPyConnection, sources: Mapping[str, DuckDBPyRelation], row_limit: Optional[int] = None

countess/plugins/correlation.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from duckdb import DuckDBPyConnection, DuckDBPyRelation
55

66
from countess import VERSION
7-
from countess.core.parameters import ColumnChoiceParam, ColumnOrNoneChoiceParam
7+
from countess.core.parameters import NumericColumnChoiceParam, ColumnOrNoneChoiceParam
88
from countess.core.plugins import DuckdbSimplePlugin
99
from countess.utils.duckdb import duckdb_escape_identifier
1010

@@ -18,8 +18,8 @@ class CorrelationPlugin(DuckdbSimplePlugin):
1818
link = "https://countess-project.github.io/CountESS/included-plugins/#correlation-tool"
1919

2020
group = ColumnOrNoneChoiceParam("Group")
21-
column1 = ColumnChoiceParam("Column X")
22-
column2 = ColumnChoiceParam("Column Y")
21+
column1 = NumericColumnChoiceParam("Column X")
22+
column2 = NumericColumnChoiceParam("Column Y")
2323

2424
columns: list[str] = []
2525
dataframes: list[pd.DataFrame] = []

countess/plugins/score_scale.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
ArrayParam,
99
ChoiceParam,
1010
ColumnChoiceParam,
11+
NumericColumnChoiceParam,
1112
ColumnOrNoneChoiceParam,
1213
FloatParam,
1314
StringParam,
@@ -45,7 +46,7 @@ class ScoreScalingPlugin(DuckdbSimplePlugin):
4546
description = "Scaled Scores using variant classification"
4647
version = VERSION
4748

48-
score_col = ColumnChoiceParam("Score Column")
49+
score_col = NumericColumnChoiceParam("Score Column")
4950
classifiers = ArrayParam("Variant Classifiers", ScaleClassParam("Class"), min_size=2, max_size=2, read_only=True)
5051
group_col = ColumnOrNoneChoiceParam("Group By")
5152

countess/plugins/vampseq.py

Lines changed: 6 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -4,43 +4,27 @@
44
from duckdb import DuckDBPyConnection, DuckDBPyRelation
55

66
from countess import VERSION
7-
from countess.core.parameters import (
8-
ArrayParam,
9-
ColumnChoiceParam,
10-
FloatParam,
11-
TabularMultiParam,
12-
PerColumnArrayParam,
13-
)
7+
from countess.core.parameters import FloatParam, PerNumericColumnArrayParam, TabularMultiParam
148
from countess.core.plugins import DuckdbSimplePlugin
15-
from countess.utils.duckdb import duckdb_escape_identifier, duckdb_escape_literal, duckdb_dtype_is_numeric
9+
from countess.utils.duckdb import duckdb_escape_identifier, duckdb_escape_literal
1610

1711
logger = logging.getLogger(__name__)
1812

1913

2014
class CountColumnParam(TabularMultiParam):
2115
weight = FloatParam("Weight")
2216

17+
2318
class VampSeqScorePlugin(DuckdbSimplePlugin):
2419
name = "VAMP-seq Scoring"
2520
description = "Calculate scores from weighed bin counts"
2621
version = VERSION
2722

28-
columns = PerColumnArrayParam("Columns", CountColumnParam("Column"))
29-
30-
def prepare(self, ddbc: DuckDBPyConnection, source: Optional[DuckDBPyRelation]) -> None:
31-
# Override prepare to only select numeric columns.
32-
if source is None:
33-
self.set_column_choices([])
34-
else:
35-
self.set_column_choices([
36-
c for c, d in zip(source.columns, source.dtypes)
37-
if duckdb_dtype_is_numeric(d)
38-
])
23+
columns = PerNumericColumnArrayParam("Columns", CountColumnParam("Column"))
3924

4025
def execute(
4126
self, ddbc: DuckDBPyConnection, source: DuckDBPyRelation, row_limit: Optional[int] = None
4227
) -> Optional[DuckDBPyRelation]:
43-
4428
weighted_columns = {
4529
duckdb_escape_identifier(name): duckdb_escape_literal(param.weight.value)
4630
for name, param in self.columns.get_column_params()
@@ -50,8 +34,8 @@ def execute(
5034
if not weighted_columns:
5135
return source
5236

53-
weighted_counts = ' + '.join(f"{k} * {v}" for k, v in weighted_columns.items())
54-
total_counts = ' + '.join(k for k in weighted_columns.keys())
37+
weighted_counts = " + ".join(f"{k} * {v}" for k, v in weighted_columns.items())
38+
total_counts = " + ".join(k for k in weighted_columns.keys())
5539

5640
proj = f"*, ({weighted_counts}) / ({total_counts}) as score"
5741

tests/plugins/test_variant.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ def test_variant_ref_value():
1111
plugin.set_parameter("reference", "AGAAGTAGAGG")
1212
plugin.set_parameter("variant.seq_type", "g")
1313
plugin.set_parameter("variant.output", "out")
14-
plugin.set_column_choices(["seq"])
14+
plugin.set_column_choices({"seq": False})
1515

1616
assert plugin.transform({"seq": "TGAAGTAGAGG"})["out"] == "g.1A>T"
1717
assert plugin.transform({"seq": "AGAAGTTGTGG"})["out"] == "g.[7A>T;9A>T]"
@@ -24,7 +24,7 @@ def test_variant_ref_column():
2424
plugin.set_parameter("reference", "— ref")
2525
plugin.set_parameter("variant.seq_type", "g")
2626
plugin.set_parameter("variant.output", "out")
27-
plugin.set_column_choices(["seq", "ref"])
27+
plugin.set_column_choices({"seq": False, "ref": False})
2828

2929
assert plugin.transform({"ref": "TACACACAG", "seq": "TACAGACAG"})["out"] == "g.5C>G"
3030
assert plugin.transform({"ref": "ATGGTTGGTTC", "seq": "ATGGTTGGTGGTTCG"})["out"] == "g.[7_9dup;11_12insG]"
@@ -37,7 +37,7 @@ def test_variant_ref_offset():
3737
plugin.set_parameter("variant.offset", "— offs")
3838
plugin.set_parameter("variant.seq_type", "g")
3939
plugin.set_parameter("variant.output", "out")
40-
plugin.set_column_choices(["seq", "offs"])
40+
plugin.set_column_choices({"seq": False, "offs": True})
4141

4242
assert plugin.transform({"seq": "TGAAGTAGAGG", "offs": "0"})["out"] == "g.1A>T"
4343
assert plugin.transform({"seq": "AGAAGTTGTGG", "offs": "10"})["out"] == "g.[17A>T;19A>T]"
@@ -65,7 +65,7 @@ def test_variant_ref_offset_minus():
6565
plugin.set_parameter("variant.seq_type", "g")
6666
plugin.set_parameter("variant.minus_strand", True)
6767
plugin.set_parameter("variant.output", "out")
68-
plugin.set_column_choices(["seq", "offs"])
68+
plugin.set_column_choices({"seq": False, "offs": True})
6969

7070
assert plugin.transform({"seq": "TGAAGTAGAGG"})["out"] == "g.1011T>A"
7171
assert plugin.transform({"seq": "AGAAGTTGTGG"})["out"] == "g.[1003T>A;1005T>A]"
@@ -79,7 +79,7 @@ def test_variant_too_many():
7979
plugin.set_parameter("variant.seq_type", "g")
8080
plugin.set_parameter("variant.output", "out")
8181
plugin.set_parameter("variant.maxlen", 2)
82-
plugin.set_column_choices(["seq", "offs"])
82+
plugin.set_column_choices({"seq": False, "offs": True})
8383

8484
assert plugin.transform({"seq": "TGAAGTAGAGG"})["out"] == "g.1A>T"
8585
assert plugin.transform({"seq": "AGAAGTTGTGG"})["out"] == "g.[7A>T;9A>T]"

tests/test_parameters.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,7 @@ def test_coindex():
226226
def test_columnorintegerparam():
227227
df = pd.DataFrame([[1, 2], [3, 4]], columns=["a", "b"])
228228
cp = ColumnOrIntegerParam("x")
229-
cp.set_column_choices(["a", "b"])
229+
cp.set_column_choices({"a": True, "b": True})
230230

231231
assert cp.get_column_name() is None
232232

@@ -240,7 +240,7 @@ def test_columnorintegerparam():
240240
assert cp.get_column_name() == "a"
241241
assert isinstance(cp.get_column_or_value(df, False), pd.Series)
242242

243-
cp.set_column_choices(["c", "d"])
243+
cp.set_column_choices({"c": True, "d": True})
244244
assert cp.choice is None
245245

246246
cp.value = "hello"
@@ -250,7 +250,7 @@ def test_columnorintegerparam():
250250
def test_columngroup():
251251
df = pd.DataFrame([], columns=["one_two", "one_three", "two_one", "two_two", "two_three", "three_four_five"])
252252
cp = ColumnGroupOrNoneChoiceParam("x")
253-
cp.set_column_choices(df.columns)
253+
cp.set_column_choices({c: True for c in df.columns})
254254
assert cp.is_none()
255255
assert "one_*" in cp.choices
256256
assert "two_*" in cp.choices
@@ -330,11 +330,11 @@ def test_pcap():
330330
pp = IntegerParam("x")
331331
ap = PerColumnArrayParam("y", param=pp)
332332

333-
ap.set_column_choices(["a", "b", "c"])
333+
ap.set_column_choices({"a": True, "b": True, "c": True })
334334
assert len(ap) == 3
335335
apa, apb, apc = list(ap)
336336

337-
ap.set_column_choices(["c", "d", "b", "a"])
337+
ap.set_column_choices({"c": True, "d": True, "b": True, "a": True})
338338
assert len(ap) == 4
339339
assert ap[0] is apc
340340
assert ap[2] is apb

0 commit comments

Comments
 (0)