Skip to content

Commit d10662a

Browse files
committed
frequency calculator plugin, not sure if this is staying
1 parent e71f58c commit d10662a

2 files changed

Lines changed: 56 additions & 0 deletions

File tree

countess/plugins/frequency.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
import logging
2+
from typing import Iterable, Optional
3+
4+
from duckdb import DuckDBPyConnection, DuckDBPyRelation
5+
6+
from countess import VERSION
7+
from countess.core.parameters import BooleanParam, ColumnOrNoneChoiceParam, PerNumericColumnArrayParam
8+
from countess.core.plugins import DuckdbSqlPlugin
9+
from countess.utils.duckdb import duckdb_escape_identifier
10+
11+
logger = logging.getLogger(__name__)
12+
13+
14+
class FrequencyPlugin(DuckdbSqlPlugin):
    """Add a frequency column for each selected numeric (count) column.

    For every column ticked under "Columns", the generated SQL appends a
    ``<column>_freq`` column containing the row's value divided by the sum
    of that column.  The sum is taken over the whole table, or separately
    for each distinct value of the optional "Group By" column.
    """

    name = "Calculate Frequencies"
    description = "Calculate frequencies from counts"
    version = VERSION

    # One "Convert?" checkbox per numeric column of the input.
    columns = PerNumericColumnArrayParam("Columns", BooleanParam("Convert?"))
    # Optional column whose values partition the rows before summing.
    group_col = ColumnOrNoneChoiceParam("Group By")

    def prepare(self, ddbc: DuckDBPyConnection, source: Optional[DuckDBPyRelation]) -> None:
        """Prepare the plugin and pick default columns to convert.

        After delegating to the base class, if no column is currently
        selected, every column whose label contains ``count`` is selected.
        """
        super().prepare(ddbc, source)

        # set default values for converting "count" columns
        if not any(cp.value for cp in self.columns):
            for cp in self.columns:
                if "count" in cp.label:
                    cp.value = True

    def sql(self, table_name: str, columns: Iterable[str]) -> Optional[str]:
        """Build the SQL which appends the frequency columns.

        Args:
            table_name: name of the input table, interpolated as-is.
            columns: input column names (not used by this plugin).

        Returns:
            The SQL query string, or ``None`` when no column is selected
            for conversion.
        """
        # (index, escaped column identifier, raw label) for each selected
        # column; the index gives each per-group sum a unique, safe alias.
        selected = [
            (n, duckdb_escape_identifier(cp.label), cp.label)
            for n, cp in enumerate(self.columns)
            if cp.value
        ]
        if not selected:
            return None

        sums = ", ".join(f"sum(T0.{col}) as sum_{n}" for n, col, _ in selected)

        # A zero (or NULL) total yields a frequency of 0 rather than a
        # division by zero.
        freqs = ", ".join(
            f"CASE WHEN T1.sum_{n} > 0 THEN T0.{col} / T1.sum_{n} ELSE 0 END "
            f"as {duckdb_escape_identifier(label + '_freq')}"
            for n, col, label in selected
        )

        # Qualify the grouping column with T0 so it cannot be confused with
        # T1's derived columns (`score_group`, `sum_N`) in the join
        # condition.  With no grouping column, the constant 1 puts every
        # row in a single group.
        if self.group_col.is_not_none():
            group_expr = "T0." + duckdb_escape_identifier(self.group_col.value)
        else:
            group_expr = "1"

        # `IS NOT DISTINCT FROM` treats NULL as equal to NULL, so rows whose
        # grouping value is NULL are matched to the NULL group produced by
        # GROUP BY instead of being dropped by the join.
        return f"""
            select T0.*, {freqs}
            from {table_name} T0 join (
                select {group_expr} as score_group, {sums}
                from {table_name} T0
                group by score_group
            ) T1 on ({group_expr} is not distinct from T1.score_group)
        """

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ score = "countess.plugins.score:ScoringPlugin"
7171
score_scaling = "countess.plugins.score_scale:ScoreScalingPlugin"
7272
vampseq_score = "countess.plugins.vampseq:VampSeqScorePlugin"
7373
variant_classifier = "countess.plugins.variant:VariantClassifier"
74+
frequency = "countess.plugins.frequency:FrequencyPlugin"
7475

7576
[project.entry-points.gui_scripts]
7677
countess_gui = "countess.gui.main:main"

0 commit comments

Comments
 (0)