Skip to content

Commit d966347

Browse files
committed
Fix some issues with pivots on string values #38
1 parent 5b39ec5 commit d966347

1 file changed

Lines changed: 24 additions & 11 deletions

File tree

countess/plugins/pivot.py

Lines changed: 24 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -6,7 +6,7 @@
66
from countess import VERSION
77
from countess.core.parameters import BooleanParam, ChoiceParam, PerColumnArrayParam
88
from countess.core.plugins import DuckdbSimplePlugin
9-
from countess.utils.duckdb import duckdb_choose_special, duckdb_escape_identifier, duckdb_escape_literal
9+
from countess.utils.duckdb import duckdb_choose_special, duckdb_escape_identifier, duckdb_escape_literal, duckdb_dtype_is_numeric
1010

1111
logger = logging.getLogger(__name__)
1212

@@ -40,6 +40,12 @@ def execute(
4040
if not expand_cols or not pivot_cols:
4141
return source
4242

43+
column_is_numeric = {
44+
name: duckdb_dtype_is_numeric(dt)
45+
for name, dt in zip(source.columns, source.dtypes)
46+
if name in expand_cols
47+
}
48+
4349
# Override the default pivoted column naming convention with
4450
# our own custom one ... this is what previous CountESS used
4551
# but it frankly could be better or customizable.
@@ -55,9 +61,15 @@ def execute(
5561
# definitively pick out the pivot output columns later ...
5662
pivot_char = duckdb_choose_special(index_cols + expand_cols)
5763

64+
# produces clauses like `SUM("foo") AS "~foo"` and the pivot will
65+
# then make columns like "value_~foo" (short names) or "column_value__~foo"
66+
# (long names)
5867
using_str = ", ".join(
59-
"%s(%s) AS %s"
60-
% (self.aggfunc.value, duckdb_escape_identifier(ec), duckdb_escape_identifier(pivot_char + ec))
68+
"%s(%s) AS %s" % (
69+
self.aggfunc.value if column_is_numeric[ec] else 'string_agg',
70+
duckdb_escape_identifier(ec),
71+
duckdb_escape_identifier(pivot_char + ec)
72+
)
6173
for ec in expand_cols
6274
)
6375
group_str = ", ".join(duckdb_escape_identifier(ic) for ic in index_cols)
@@ -69,14 +81,15 @@ def execute(
6981
logger.debug("PivotPlugin.execute query_str %s", query_str)
7082
rel = ddbc.sql(query_str)
7183

72-
project_str = f"COLUMNS('(.*)_{pivot_char}(.*)')"
73-
if self.default_0:
74-
project_str = f"COALESCE({project_str}, 0)"
75-
76-
if self.short_names:
77-
project_str += " AS '\\2_\\1'"
78-
else:
79-
project_str += " AS '\\2__\\1'"
84+
# Rather than keep the duckdb columns we match them and switch them to our
85+
# preferred format.
86+
project_str = ', '.join(
87+
("COALESCE(" if self.default_0 and column_is_numeric[ec] else "") +
88+
"COLUMNS(" + duckdb_escape_literal('(.*)_'+pivot_char+ec) + ")" +
89+
(", 0)" if self.default_0 and column_is_numeric[ec] else "") +
90+
" AS " + duckdb_escape_literal(ec + ("_" if self.short_names else "__") + r"\1")
91+
for ec in expand_cols
92+
)
8093

8194
if index_cols:
8295
project_str = (", ".join([duckdb_escape_identifier(ic) for ic in index_cols])) + ", " + project_str

0 commit comments

Comments
 (0)