Skip to content

Commit a08e0b8

Browse files
committed
Supported MAX/MIN/MAXIF/MINIF
1 parent f04572b commit a08e0b8

6 files changed

Lines changed: 136 additions & 3 deletions

File tree

forms/core/forms.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,27 @@ def compute_formula(self, formula_str: str, num_formulas: int = -1, **kwargs) ->
278278
print(f"An error occurred: {e}")
279279
traceback.print_exception(*sys.exc_info())
280280

281+
def print_sql_strings(self, formula_str: str, num_formulas: int = -1, **kwargs):
    """Print the SQL string(s) the DB executor produces for a formula.

    The formula is parsed, validated for the DB executor, and passed through
    the plan rewriter before being handed to ``DBExecutor.get_sql_strings``.
    Each returned string is written to stdout; nothing is returned.

    Args:
        formula_str: The formula text to translate.
        num_formulas: Number of rows (starting at ``START_ROW_ID``) the
            execution context covers. Any value <= 0 means "all rows"
            (``self.num_rows``).
        **kwargs: Accepted but not used by this method.

    A ``FormSException`` raised anywhere in the pipeline is caught and
    reported on stdout together with its traceback; it is not re-raised.
    """
    try:
        root = parse_formula_str(formula_str)
        validate(FunctionExecutor.DB_EXECUTOR, self.num_rows, self.num_columns, root)
        root = PlanRewriter(self.db_config).rewrite_plan(root)

        if num_formulas <= 0:
            num_formulas = self.num_rows
        exec_context = DBExecContext(
            self.connection, self.cursor, self.base_table, START_ROW_ID, START_ROW_ID + num_formulas
        )
        executor = DBExecutor(self.db_config, exec_context, self.metrics_tracker)
        try:
            sql_strings = executor.get_sql_strings(root)
        finally:
            # Release executor resources even when translation fails, so an
            # error while building the SQL does not skip the clean-up step.
            executor.clean_up()

        for s in sql_strings:
            print(s)
    except FormSException as e:
        print(f"An error occurred: {e}")
        traceback.print_exception(*sys.exc_info())
301+
281302
def print_workbook(self, num_rows=10, keep_original_labels=False):
282303
order_by_clause = ", ".join(self.db_config.order_key)
283304
query = f"SELECT * FROM {self.db_config.table_name} ORDER BY {order_by_clause} LIMIT {num_rows}"

forms/executor/dbexecutor/dbexecutor.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,21 @@ def __init__(
4949
self.exec_context = exec_context
5050
self.metrics_tracker = metrics_tracker
5151

52+
def get_sql_strings(self, formula_plan: PlanNode) -> list:
    """Translate a formula plan into SQL text without executing it.

    Only the first subtree handed out by the scheduler is translated, so
    the returned list currently holds exactly one string.

    Args:
        formula_plan: Root of the formula plan to translate.

    Returns:
        A one-element list containing the rendered SQL string.
    """
    tree = from_plan_to_execution_tree(formula_plan, self.exec_context.base_table)
    # NOTE(review): execute_formula_plan also passes
    # self.db_config.enable_pipelining to Scheduler; confirm omitting it
    # here is intentional.
    scheduler = Scheduler(tree)

    subtree = scheduler.next_subtree()
    # Queried after next_subtree(): the subtree is treated as the root one
    # when the scheduler has nothing further to hand out.
    is_root = not scheduler.has_next_subtree()

    if isinstance(tree, DBFuncExecNode):
        table_name = tree.intermediate_table_name
    else:
        table_name = ""

    composed = translate(subtree, self.exec_context, table_name, is_root)
    return [composed.as_string(self.exec_context.conn)]
66+
5267
def execute_formula_plan(self, formula_plan: PlanNode) -> pd.DataFrame:
5368
exec_tree = from_plan_to_execution_tree(formula_plan, self.exec_context.base_table)
5469
scheduler = Scheduler(exec_tree, self.db_config.enable_pipelining)

forms/executor/dbexecutor/translation.py

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,18 @@ def translate_if_function(
242242
subtree: DBFuncExecNode, exec_context: DBExecContext, base_table: Composable
243243
) -> Composable:
244244
children = subtree.children
245+
if isinstance(children[1], DBLitExecNode) and children[1].literal == '"NULL"':
246+
return sql.SQL(
247+
"""CASE
248+
WHEN {condition}
249+
THEN NULL
250+
ELSE {false_result}
251+
END"""
252+
).format(
253+
condition=translate_window_clause(children[0], exec_context, base_table),
254+
false_result=translate_window_clause(children[2], exec_context, base_table),
255+
)
256+
245257
return sql.SQL(
246258
"""CASE
247259
WHEN {condition}
@@ -276,6 +288,14 @@ def translate_aggregate_functions(
276288
function=sql.SQL(subtree.function.value),
277289
agg_column=sql.SQL("+").join(sql.Identifier(col) for col in child.cols),
278290
)
291+
elif subtree.function == Function.MAX or subtree.function == Function.MIN:
292+
row_func = "MAX" if subtree.function == Function.MAX else "MIN"
293+
col_func = "GREATEST" if subtree.function == Function.MAX else "LEAST"
294+
agg_sql = sql.SQL("""{row_func}({col_func}({agg_column}))""").format(
295+
row_func=sql.SQL(row_func),
296+
col_func=sql.SQL(col_func),
297+
agg_column=sql.SQL(",").join(sql.Identifier(col) for col in child.cols),
298+
)
279299
elif subtree.function == Function.COUNT:
280300
agg_sql = sql.SQL("""{num_columns} * {function}({agg_column})""").format(
281301
num_columns=sql.Literal(len(child.cols)),
@@ -326,7 +346,30 @@ def translate_aggregate_if_functions(
326346
)
327347
for i in range(len(input_child.cols))
328348
)
329-
)
349+
)
350+
elif subtree.function == Function.MAXIF or subtree.function == Function.MINIF:
351+
row_func = "MAX" if subtree.function == Function.MAXIF else "MIN"
352+
col_func = "GREATEST" if subtree.function == Function.MAXIF else "LEAST"
353+
agg_sql = sql.SQL("""{row_func}({col_func}({agg_expression}))""").format(
354+
row_func=sql.SQL(row_func),
355+
col_func=sql.SQL(col_func),
356+
agg_expression=sql.SQL(",").join(
357+
sql.SQL(
358+
"""
359+
CASE
360+
WHEN {input_col}{condition}
361+
THEN {output_col}
362+
ELSE NULL
363+
END
364+
"""
365+
).format(
366+
input_col=sql.Identifier(input_cols[i]),
367+
condition=sql.SQL(condition),
368+
output_col=sql.Identifier(output_cols[i]),
369+
)
370+
for i in range(len(input_child.cols))
371+
)
372+
)
330373
elif subtree.function == Function.COUNTIF:
331374
agg_sql = sql.SQL("""SUM({agg_expression})""").format(
332375
agg_expression=sql.SQL("+").join(

forms/utils/functions.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@ class Function(Enum):
3434
SUMIF = "sumif"
3535
COUNTIF = "countif"
3636
AVERAGEIF = "averageif"
37+
MAXIF = "maxif"
38+
MINIF = "minif"
3739

3840
# Text Functions
3941
CONCAT = "concat"
@@ -234,9 +236,13 @@ def from_function_str(function_str: str) -> Function:
234236
Function.DIVIDE,
235237
# Aggregation Functions
236238
Function.SUM,
239+
Function.MAX,
240+
Function.MIN,
237241
Function.COUNT,
238242
Function.AVG,
239243
Function.SUMIF,
244+
Function.MAXIF,
245+
Function.MINIF,
240246
Function.COUNTIF,
241247
Function.AVERAGEIF,
242248
# Control Function
@@ -253,8 +259,8 @@ def from_function_str(function_str: str) -> Function:
253259
Function.INDEX,
254260
}
255261

256-
DB_AGGREGATE_FUNCTIONS = {Function.SUM, Function.COUNT, Function.AVG}
257-
DB_AGGREGATE_IF_FUNCTIONS = {Function.SUMIF, Function.COUNTIF, Function.AVERAGEIF}
262+
DB_AGGREGATE_FUNCTIONS = {Function.SUM, Function.MAX, Function.MIN, Function.COUNT, Function.AVG}
263+
DB_AGGREGATE_IF_FUNCTIONS = {Function.SUMIF, Function.MAXIF, Function.MINIF, Function.COUNTIF, Function.AVERAGEIF}
258264
DB_CELL_REFERENCE_FUNCTIONS = ARITHMETIC_FUNCTIONS | {Function.IF} | COMPARISON_FUNCTIONS
259265
DB_LOOKUP_FUNCTIONS = {
260266
Function.VLOOKUP,

tests/db_tests/test_formula_to_sql.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,3 +148,6 @@ def test_if_function(get_wb):
148148
computed_df = wb.compute_formula("=IF(A1 < 3, B1, C1)")
149149
expected_df = pd.DataFrame({"row_id": [1, 2, 3, 4], "A": [2, 2, 4, 5]})
150150
assert np.array_equal(computed_df.values, expected_df.values, equal_nan=True)
151+
152+
153+

tests/db_tests/test_translation.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
# Copyright 2022-2023 The FormS Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import pytest
16+
import os
17+
import pandas as pd
18+
import numpy as np
19+
20+
from forms.core.forms import from_db
21+
22+
23+
@pytest.fixture(scope="module")
def get_wb():
    """Module-scoped workbook built from the POSTGRES_* environment variables.

    Plan rewriting is disabled (``enable_rewriting=False``). The workbook is
    closed once the tests of the module that use it have finished.
    """
    env = os.getenv
    workbook = from_db(
        host=env("POSTGRES_HOST"),
        port=int(env("POSTGRES_PORT")),
        username=env("POSTGRES_USER"),
        password=env("POSTGRES_PASSWORD"),
        db_name=env("POSTGRES_DB"),
        table_name=env("POSTGRES_TEST_TABLE"),
        primary_key=[env("POSTGRES_PRIMARY_KEY")],
        order_key=[env("POSTGRES_ORDER_KEY")],
        enable_rewriting=False,
    )

    # Hand the workbook to the tests, then release it on teardown.
    yield workbook
    workbook.close()
41+
42+
# Check out the SQL strings
43+
def test_get_sql_strings(get_wb):
    """Smoke test: print the SQL generated for a MAX over a cell range.

    Nothing is asserted; the test passes as long as the call does not raise.
    """
    get_wb.print_sql_strings('=MAX(B1:C2)')

0 commit comments

Comments
 (0)