From 75fdbd3077c3a461fbc4fbd0ca11fe6118ca69c0 Mon Sep 17 00:00:00 2001 From: Evert Lammerts Date: Wed, 14 Jan 2026 13:02:22 +0100 Subject: [PATCH] escape idenitifiers in relation aggregations --- src/duckdb_py/pyrelation.cpp | 32 +++- .../relational_api/test_rapi_aggregations.py | 143 ++++++++++++++++++ 2 files changed, 172 insertions(+), 3 deletions(-) diff --git a/src/duckdb_py/pyrelation.cpp b/src/duckdb_py/pyrelation.cpp index 58cfcc29..2e748f8f 100644 --- a/src/duckdb_py/pyrelation.cpp +++ b/src/duckdb_py/pyrelation.cpp @@ -395,10 +395,36 @@ string DuckDBPyRelation::GenerateExpressionList(const string &function_name, vec function_name + "(" + function_parameter + ((ignore_nulls) ? " ignore nulls) " : ") ") + window_spec; } for (idx_t i = 0; i < input.size(); i++) { + // We parse the input as an expression to validate it. + auto trimmed_input = input[i]; + StringUtil::Trim(trimmed_input); + + unique_ptr expression; + try { + auto expressions = Parser::ParseExpressionList(trimmed_input); + if (expressions.size() == 1) { + expression = std::move(expressions[0]); + } + } catch (const ParserException &) { + // First attempt at parsing failed, the input might be a column name that needs quoting. + auto quoted_input = KeywordHelper::WriteQuoted(trimmed_input, '"'); + auto expressions = Parser::ParseExpressionList(quoted_input); + if (expressions.size() == 1 && expressions[0]->GetExpressionClass() == ExpressionClass::COLUMN_REF) { + expression = std::move(expressions[0]); + } + } + + if (!expression) { + throw ParserException("Invalid column expression: %s", trimmed_input); + } + + // ToString() handles escaping for all expression types + auto escaped_input = expression->ToString(); + if (function_parameter.empty()) { - expr += function_name + "(" + input[i] + ((ignore_nulls) ? " ignore nulls) " : ") ") + window_spec; + expr += function_name + "(" + escaped_input + ((ignore_nulls) ? " ignore nulls) " : ") ") + window_spec; } else { - expr += function_name + "(" + input[i] + "," + function_parameter + + expr += function_name + "(" + escaped_input + "," + function_parameter + ((ignore_nulls) ? " ignore nulls) " : ") ") + window_spec; } @@ -587,7 +613,7 @@ unique_ptr DuckDBPyRelation::Product(const std::string &column unique_ptr DuckDBPyRelation::StringAgg(const std::string &column, const std::string &sep, const std::string &groups, const std::string &window_spec, const std::string &projected_columns) { - auto string_agg_params = "\'" + sep + "\'"; + auto string_agg_params = KeywordHelper::WriteOptionallyQuoted(sep, '\''); return ApplyAggOrWin("string_agg", column, string_agg_params, groups, window_spec, projected_columns); } diff --git a/tests/fast/relational_api/test_rapi_aggregations.py b/tests/fast/relational_api/test_rapi_aggregations.py index ffb7e303..409972fc 100644 --- a/tests/fast/relational_api/test_rapi_aggregations.py +++ b/tests/fast/relational_api/test_rapi_aggregations.py @@ -416,3 +416,146 @@ def test_var_samp(self, table, f): def test_describe(self, table): assert table.describe().fetchall() is not None + + +class TestRAPIAggregationsColumnEscaping: + """Test that aggregate functions properly escape column names that need quoting.""" + + def test_reserved_keyword_column_name(self, duckdb_cursor): + # Column name "select" is a reserved SQL keyword + rel = duckdb_cursor.sql('select 1 as "select", 2 as "order"') + result = rel.sum("select").fetchall() + assert result == [(1,)] + + result = rel.avg("order").fetchall() + assert result == [(2.0,)] + + def test_column_name_with_space(self, duckdb_cursor): + rel = duckdb_cursor.sql('select 10 as "my column"') + result = rel.sum("my column").fetchall() + assert result == [(10,)] + + def test_column_name_with_quotes(self, duckdb_cursor): + # Column name containing a double quote + rel = duckdb_cursor.sql('select 5 as "col""name"') + result = rel.sum('col"name').fetchall() + assert result == [(5,)] + + def test_qualified_column_name(self, duckdb_cursor): + # Qualified column name like table.column + rel = duckdb_cursor.sql("select 42 as value") + # When using qualified names, they should be properly escaped + result = rel.sum("value").fetchall() + assert result == [(42,)] + + +class TestRAPIAggregationsExpressionPassthrough: + """Test that aggregate functions correctly pass through SQL expressions without escaping.""" + + def test_cast_expression(self, duckdb_cursor): + # Cast expressions should pass through without being quoted + rel = duckdb_cursor.sql("select 1 as v, 0 as f") + result = rel.bool_and("v::BOOL").fetchall() + assert result == [(True,)] + + result = rel.bool_or("f::BOOL").fetchall() + assert result == [(False,)] + + def test_star_expression(self, duckdb_cursor): + # Star (*) should pass through for count + rel = duckdb_cursor.sql("select 1 as a union all select 2") + result = rel.count("*").fetchall() + assert result == [(2,)] + + def test_arithmetic_expression(self, duckdb_cursor): + # Arithmetic expressions should pass through + rel = duckdb_cursor.sql("select 10 as a, 5 as b") + result = rel.sum("a + b").fetchall() + assert result == [(15,)] + + def test_function_expression(self, duckdb_cursor): + # Function calls should pass through + rel = duckdb_cursor.sql("select -5 as v") + result = rel.sum("abs(v)").fetchall() + assert result == [(5,)] + + def test_case_expression(self, duckdb_cursor): + # CASE expressions should pass through + rel = duckdb_cursor.sql("select 1 as v union all select 2 union all select 3") + result = rel.sum("case when v > 1 then v else 0 end").fetchall() + assert result == [(5,)] + + +class TestRAPIAggregationsWithInvalidInput: + """Test that only expression can be used.""" + + def test_injection_with_semicolon_is_neutralized(self, duckdb_cursor): + # Semicolon injection fails to parse as expression, gets quoted as identifier + rel = duckdb_cursor.sql("select 1 as v") + with pytest.raises(duckdb.BinderException, match="not found in FROM clause"): + rel.sum("v; drop table agg; --").fetchall() + + def test_injection_with_union_is_neutralized(self, duckdb_cursor): + # UNION fails to parse as single expression, gets quoted + rel = duckdb_cursor.sql("select 1 as v") + with pytest.raises(duckdb.BinderException, match="not found in FROM clause"): + rel.sum("v union select * from agg").fetchall() + + def test_subquery_is_contained(self, duckdb_cursor): + # Subqueries are valid expressions - they're contained within the aggregate + # and cannot break out of the expression context + rel = duckdb_cursor.sql("select 1 as v") + # This executes sum((select 1)) = sum(1) = 1 - contained, not an injection + result = rel.sum("(select 1)").fetchall() + assert result == [(1,)] + + def test_injection_closing_paren_is_neutralized(self, duckdb_cursor): + # Adding a closing paren fails to parse, gets quoted + rel = duckdb_cursor.sql("select 1 as v") + with pytest.raises(duckdb.BinderException, match="not found in FROM clause"): + rel.sum("v) from agg; drop table agg; --").fetchall() + + def test_comment_is_harmless(self, duckdb_cursor): + # SQL comments are stripped during parsing, so "v -- comment" parses as just "v" + rel = duckdb_cursor.sql("select 1 as v") + result = rel.sum("v -- this is ignored").fetchall() + assert result == [(1,)] + + def test_empty_expression_rejected(self, duckdb_cursor): + # Empty or whitespace-only expressions should be rejected + rel = duckdb_cursor.sql("select 1 as v") + with pytest.raises(duckdb.ParserException): + rel.sum("").fetchall() + + def test_whitespace_only_expression_rejected(self, duckdb_cursor): + # Whitespace-only expressions should be rejected + rel = duckdb_cursor.sql("select 1 as v") + with pytest.raises(duckdb.ParserException): + rel.sum(" ").fetchall() + + +class TestRAPIStringAggSeparatorEscaping: + """Test that string_agg separator is properly escaped as a string literal.""" + + def test_simple_separator(self, duckdb_cursor): + rel = duckdb_cursor.sql("select 'a' as s union all select 'b' union all select 'c'") + result = rel.string_agg("s", ",").fetchall() + assert result == [("a,b,c",)] + + def test_separator_with_single_quote(self, duckdb_cursor): + # Separator containing a single quote should be properly escaped + rel = duckdb_cursor.sql("select 'a' as s union all select 'b'") + result = rel.string_agg("s", "','").fetchall() + assert result == [("a','b",)] + + def test_separator_with_special_chars(self, duckdb_cursor): + rel = duckdb_cursor.sql("select 'x' as s union all select 'y'") + result = rel.string_agg("s", " | ").fetchall() + assert result == [("x | y",)] + + def test_separator_injection_attempt(self, duckdb_cursor): + # Attempt to inject via separator - should be safely quoted as string literal + rel = duckdb_cursor.sql("select 'a' as s union all select 'b'") + # This should NOT execute the injection - separator becomes a literal string + result = rel.string_agg("s", "'); drop table agg; --").fetchall() + assert result == [("a'); drop table agg; --b",)]