Skip to content

Commit a7ac59e

Browse files
committed
Optimize update to support multi-table updates
1 parent 36655b9 commit a7ac59e

1 file changed

Lines changed: 69 additions & 9 deletions

File tree

syncanysql/compiler.py

Lines changed: 69 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,25 @@ def _parse_parameter(self):
5353
expression = self._parse_conjunction() or self._parse_function() or self._parse_id_var()
5454
return self.expression(AssignParameter, this=this, expression=expression, wrapped=wrapped)
5555

56+
def _parse_update(self):
57+
table_expressions = self._parse_csv(lambda: self._parse_table(alias_tokens=self.UPDATE_ALIAS_TOKENS))
58+
for table_expression in table_expressions[1:]:
59+
table_expressions[0].append("table_expressions", table_expression)
60+
while self._match(sqlglot_parser.TokenType.JOIN):
61+
join_expression = self._parse_join(True)
62+
if join_expression:
63+
table_expressions[0].append("join_expressions", join_expression)
64+
return self.expression(
65+
sqlglot_expressions.Update,
66+
**{ # type: ignore
67+
"this": table_expressions[0],
68+
"expressions": self._match(sqlglot_parser.TokenType.SET) and self._parse_csv(self._parse_equality),
69+
"from": self._parse_from(),
70+
"where": self._parse_where(),
71+
"returning": self._parse_returning(),
72+
},
73+
)
74+
5675
def _parse_limit(self, this=None, top=False):
5776
if top or not self._match(sqlglot_parser.TokenType.LIMIT, False):
5877
return sqlglot_parser.Parser._parse_limit(self, this, top)
@@ -95,6 +114,17 @@ def assign_parameter_sql(self, expression):
95114
this = f"{{{this}}}" if expression.args.get("wrapped") else f"{this}"
96115
return f"""{self.PARAMETER_TOKEN}{this} := {self.sql(expression, "expression")}"""
97116

117+
def update_sql(self, expression):
118+
this = ", ".join([self.sql(expression, "this")] + [self.sql(table_expression) for table_expression in expression.args["this"].args["table_expressions"]]) \
119+
if expression.args["this"].args.get("table_expressions") else self.sql(expression, "this")
120+
join_sql = self.expressions(expression.args.get("this"), "join_expressions", flat=True, sep="")
121+
set_sql = self.expressions(expression, flat=True)
122+
from_sql = self.sql(expression, "from")
123+
where_sql = self.sql(expression, "where")
124+
returning = self.sql(expression, "returning")
125+
sql = f"UPDATE {this} {join_sql} SET {set_sql}{from_sql}{where_sql}{returning}"
126+
return self.prepend_ctes(expression, sql)
127+
98128

99129
class Compiler(object):
100130
ESCAPE_CHARS = ['\\\\a', '\\\\b', '\\\\f', '\\\\n', '\\\\r', '\\\\t', '\\\\v', '\\\\0']
@@ -2704,7 +2734,7 @@ def parse(expression):
27042734
"primary_keys": [],
27052735
}
27062736

2707-
def parse_column(self, expression, config, arguments, primary_table):
2737+
def parse_column(self, expression, config, arguments, primary_table=None):
27082738
dot_keys, convert_typing_filter = [], None
27092739
if isinstance(expression, sqlglot_expressions.Dot):
27102740
def parse_dot(dot_expression):
@@ -3067,6 +3097,9 @@ def optimize_rewrite_multi_select(self, expression, config, arguments, from_expr
30673097
on_expressions = join_table["on_expressions"] + join_table["const_expressions"]
30683098
if on_expressions:
30693099
sql.append("ON " + " AND ".join([self.generate_sql(on_expression) for on_expression in on_expressions]))
3100+
if expression.args.get("joins"):
3101+
for join_expression in expression.args["joins"]:
3102+
sql.append(self.generate_sql(join_expression))
30703103

30713104
if expression.args.get("where"):
30723105
sql.append(self.generate_sql(expression.args["where"]))
@@ -3153,6 +3186,9 @@ def resort_join_tables(current_table, on_expressions, join_tables):
31533186
on_expressions = join_table["on_expressions"] + join_table["const_expressions"] + join_table["calculate_expressions"]
31543187
if on_expressions:
31553188
sql.append("ON " + " AND ".join([self.generate_sql(on_expression) for on_expression in on_expressions]))
3189+
if expression.args.get("joins"):
3190+
for join_expression in expression.args["joins"]:
3191+
sql.append(self.generate_sql(join_expression))
31563192

31573193
where_expressions = selected_table["const_expressions"] + primary_table["calculate_expressions"] + selected_table["calculate_expressions"]
31583194
if where_expressions:
@@ -3208,6 +3244,9 @@ def optimize_rewrite_inner_join(self, expression, config, arguments):
32083244
sql.append("LEFT JOIN " + self.generate_sql(join_expression.args["this"]))
32093245
if join_expression.args.get("on"):
32103246
sql.append("ON " + self.generate_sql(join_expression.args["on"]))
3247+
if expression.args.get("joins"):
3248+
for join_expression in expression.args["joins"]:
3249+
sql.append(self.generate_sql(join_expression))
32113250

32123251
inner_condition_sql = " AND ".join(["%s.%s IS NOT NULL" % (calculate_field["table_name"], calculate_field["column_name"])
32133252
for calculate_field in inner_calculate_fields])
@@ -3231,27 +3270,41 @@ def optimize_rewrite_inner_join(self, expression, config, arguments):
32313270
return maybe_parse(" ".join(sql), dialect=CompilerDialect)
32323271

32333272
def optimize_rewrite_update(self, expression, config, arguments):
3234-
primary_table = self.optimize_rewrite_parse_table(expression, config, arguments, expression.args["this"])
3235-
if not primary_table["table_name"]:
3273+
primary_tables = [self.optimize_rewrite_parse_table(expression, config, arguments, expression.args["this"])]
3274+
if not primary_tables[0]["table_name"]:
32363275
return expression
3237-
3238-
primary_keys, set_expressions = [], []
3276+
if expression.args["this"].args.get("table_expressions"):
3277+
for table_expression in expression.args["this"].args["table_expressions"]:
3278+
primary_table = self.optimize_rewrite_parse_table(expression, config, arguments, table_expression)
3279+
if not primary_table["table_name"]:
3280+
return expression
3281+
primary_tables.append(primary_table)
3282+
3283+
primary_table, primary_keys, set_expressions = (primary_tables[0] if len(primary_tables) == 1 else None), [], []
32393284
for set_expression in expression.args["expressions"]:
32403285
if not isinstance(set_expression, sqlglot_expressions.EQ):
32413286
raise SyncanySqlCompileException('error set expression, only supports the = sign assignment operation, related sql "%s"' % self.to_sql(expression))
32423287
if not isinstance(set_expression.args["this"], sqlglot_expressions.Column):
32433288
raise SyncanySqlCompileException('error set expression, the assigned item must be a table field, related sql "%s"' % self.to_sql(expression))
3244-
column = self.parse_column(set_expression.args["this"], config, arguments, primary_table)
3245-
if column["table_name"] and column["table_name"] != primary_table["table_name"]:
3289+
column = self.parse_column(set_expression.args["this"], config, arguments, None)
3290+
if primary_table is None:
3291+
for pt in primary_tables:
3292+
if column["table_name"] != pt["table_name"]:
3293+
continue
3294+
primary_table = pt
3295+
break
3296+
if not primary_table or (column["table_name"] and column["table_name"] != primary_table["table_name"]):
32463297
raise SyncanySqlCompileException('error set expression, the assigned item must be a table field, related sql "%s"' % self.to_sql(expression))
32473298
if self.is_column(set_expression.args["expression"], config, arguments):
3248-
value_column = self.parse_column(set_expression.args["expression"], config, arguments, primary_table)
3299+
value_column = self.parse_column(set_expression.args["expression"], config, arguments, None)
32493300
if ((not value_column["table_name"] or column["table_name"] == value_column["table_name"])
32503301
and column["column_name"] == value_column["column_name"]):
32513302
primary_keys.append(column["column_name"])
32523303
continue
32533304
set_expressions.append({"column": column, "expression": set_expression.args["expression"]})
32543305

3306+
if not primary_table:
3307+
raise SyncanySqlCompileException('unknown primary table, related sql "%s"' % self.to_sql(expression))
32553308
if not primary_keys and not primary_table.get("primary_keys"):
32563309
raise SyncanySqlCompileException('unknown primary key, related sql "%s"' % self.to_sql(expression))
32573310
sql = ["INSERT INTO"]
@@ -3286,7 +3339,14 @@ def optimize_rewrite_update(self, expression, config, arguments):
32863339
sql.append("SELECT")
32873340
sql.append(", ".join(select_sql))
32883341
sql.append("FROM")
3289-
sql.append(self.generate_sql(expression.args["this"]))
3342+
if expression.args["this"].args.get("table_expressions"):
3343+
sql.append(", ".join([self.generate_sql(expression.args["this"])] +
3344+
[self.generate_sql(table_expression) for table_expression in expression.args["this"].args["table_expressions"]]))
3345+
else:
3346+
sql.append(self.generate_sql(expression.args["this"]))
3347+
if expression.args["this"].args.get("join_expressions"):
3348+
for join_expression in expression.args["this"].args.get("join_expressions"):
3349+
sql.append(self.generate_sql(join_expression))
32903350
if expression.args.get("where"):
32913351
sql.append(self.generate_sql(expression.args["where"]))
32923352
return maybe_parse(" ".join(sql), dialect=CompilerDialect)

0 commit comments

Comments
 (0)