@@ -53,6 +53,25 @@ def _parse_parameter(self):
5353 expression = self ._parse_conjunction () or self ._parse_function () or self ._parse_id_var ()
5454 return self .expression (AssignParameter , this = this , expression = expression , wrapped = wrapped )
5555
56+ def _parse_update (self ):
57+ table_expressions = self ._parse_csv (lambda : self ._parse_table (alias_tokens = self .UPDATE_ALIAS_TOKENS ))
58+ for table_expression in table_expressions [1 :]:
59+ table_expressions [0 ].append ("table_expressions" , table_expression )
60+ while self ._match (sqlglot_parser .TokenType .JOIN ):
61+ join_expression = self ._parse_join (True )
62+ if join_expression :
63+ table_expressions [0 ].append ("join_expressions" , join_expression )
64+ return self .expression (
65+ sqlglot_expressions .Update ,
66+ ** { # type: ignore
67+ "this" : table_expressions [0 ],
68+ "expressions" : self ._match (sqlglot_parser .TokenType .SET ) and self ._parse_csv (self ._parse_equality ),
69+ "from" : self ._parse_from (),
70+ "where" : self ._parse_where (),
71+ "returning" : self ._parse_returning (),
72+ },
73+ )
74+
5675 def _parse_limit (self , this = None , top = False ):
5776 if top or not self ._match (sqlglot_parser .TokenType .LIMIT , False ):
5877 return sqlglot_parser .Parser ._parse_limit (self , this , top )
@@ -95,6 +114,17 @@ def assign_parameter_sql(self, expression):
95114 this = f"{{{ this } }}" if expression .args .get ("wrapped" ) else f"{ this } "
96115 return f"""{ self .PARAMETER_TOKEN } { this } := { self .sql (expression , "expression" )} """
97116
117+ def update_sql (self , expression ):
118+ this = ", " .join ([self .sql (expression , "this" )] + [self .sql (table_expression ) for table_expression in expression .args ["this" ].args ["table_expressions" ]]) \
119+ if expression .args ["this" ].args .get ("table_expressions" ) else self .sql (expression , "this" )
120+ join_sql = self .expressions (expression .args .get ("this" ), "join_expressions" , flat = True , sep = "" )
121+ set_sql = self .expressions (expression , flat = True )
122+ from_sql = self .sql (expression , "from" )
123+ where_sql = self .sql (expression , "where" )
124+ returning = self .sql (expression , "returning" )
125+ sql = f"UPDATE { this } { join_sql } SET { set_sql } { from_sql } { where_sql } { returning } "
126+ return self .prepend_ctes (expression , sql )
127+
98128
99129class Compiler (object ):
100130 ESCAPE_CHARS = ['\\ \\ a' , '\\ \\ b' , '\\ \\ f' , '\\ \\ n' , '\\ \\ r' , '\\ \\ t' , '\\ \\ v' , '\\ \\ 0' ]
@@ -2704,7 +2734,7 @@ def parse(expression):
27042734 "primary_keys" : [],
27052735 }
27062736
2707- def parse_column (self , expression , config , arguments , primary_table ):
2737+ def parse_column (self , expression , config , arguments , primary_table = None ):
27082738 dot_keys , convert_typing_filter = [], None
27092739 if isinstance (expression , sqlglot_expressions .Dot ):
27102740 def parse_dot (dot_expression ):
@@ -3067,6 +3097,9 @@ def optimize_rewrite_multi_select(self, expression, config, arguments, from_expr
30673097 on_expressions = join_table ["on_expressions" ] + join_table ["const_expressions" ]
30683098 if on_expressions :
30693099 sql .append ("ON " + " AND " .join ([self .generate_sql (on_expression ) for on_expression in on_expressions ]))
3100+ if expression .args .get ("joins" ):
3101+ for join_expression in expression .args ["joins" ]:
3102+ sql .append (self .generate_sql (join_expression ))
30703103
30713104 if expression .args .get ("where" ):
30723105 sql .append (self .generate_sql (expression .args ["where" ]))
@@ -3153,6 +3186,9 @@ def resort_join_tables(current_table, on_expressions, join_tables):
31533186 on_expressions = join_table ["on_expressions" ] + join_table ["const_expressions" ] + join_table ["calculate_expressions" ]
31543187 if on_expressions :
31553188 sql .append ("ON " + " AND " .join ([self .generate_sql (on_expression ) for on_expression in on_expressions ]))
3189+ if expression .args .get ("joins" ):
3190+ for join_expression in expression .args ["joins" ]:
3191+ sql .append (self .generate_sql (join_expression ))
31563192
31573193 where_expressions = selected_table ["const_expressions" ] + primary_table ["calculate_expressions" ] + selected_table ["calculate_expressions" ]
31583194 if where_expressions :
@@ -3208,6 +3244,9 @@ def optimize_rewrite_inner_join(self, expression, config, arguments):
32083244 sql .append ("LEFT JOIN " + self .generate_sql (join_expression .args ["this" ]))
32093245 if join_expression .args .get ("on" ):
32103246 sql .append ("ON " + self .generate_sql (join_expression .args ["on" ]))
3247+ if expression .args .get ("joins" ):
3248+ for join_expression in expression .args ["joins" ]:
3249+ sql .append (self .generate_sql (join_expression ))
32113250
32123251 inner_condition_sql = " AND " .join (["%s.%s IS NOT NULL" % (calculate_field ["table_name" ], calculate_field ["column_name" ])
32133252 for calculate_field in inner_calculate_fields ])
@@ -3231,27 +3270,41 @@ def optimize_rewrite_inner_join(self, expression, config, arguments):
32313270 return maybe_parse (" " .join (sql ), dialect = CompilerDialect )
32323271
32333272 def optimize_rewrite_update (self , expression , config , arguments ):
3234- primary_table = self .optimize_rewrite_parse_table (expression , config , arguments , expression .args ["this" ])
3235- if not primary_table ["table_name" ]:
3273+ primary_tables = [ self .optimize_rewrite_parse_table (expression , config , arguments , expression .args ["this" ])]
3274+ if not primary_tables [ 0 ] ["table_name" ]:
32363275 return expression
3237-
3238- primary_keys , set_expressions = [], []
3276+ if expression .args ["this" ].args .get ("table_expressions" ):
3277+ for table_expression in expression .args ["this" ].args ["table_expressions" ]:
3278+ primary_table = self .optimize_rewrite_parse_table (expression , config , arguments , table_expression )
3279+ if not primary_table ["table_name" ]:
3280+ return expression
3281+ primary_tables .append (primary_table )
3282+
3283+ primary_table , primary_keys , set_expressions = (primary_tables [0 ] if len (primary_tables ) == 1 else None ), [], []
32393284 for set_expression in expression .args ["expressions" ]:
32403285 if not isinstance (set_expression , sqlglot_expressions .EQ ):
32413286 raise SyncanySqlCompileException ('error set expression, only supports the = sign assignment operation, related sql "%s"' % self .to_sql (expression ))
32423287 if not isinstance (set_expression .args ["this" ], sqlglot_expressions .Column ):
32433288 raise SyncanySqlCompileException ('error set expression, the assigned item must be a table field, related sql "%s"' % self .to_sql (expression ))
3244- column = self .parse_column (set_expression .args ["this" ], config , arguments , primary_table )
3245- if column ["table_name" ] and column ["table_name" ] != primary_table ["table_name" ]:
3289+ column = self .parse_column (set_expression .args ["this" ], config , arguments , None )
3290+ if primary_table is None :
3291+ for pt in primary_tables :
3292+ if column ["table_name" ] != pt ["table_name" ]:
3293+ continue
3294+ primary_table = pt
3295+ break
3296+ if not primary_table or (column ["table_name" ] and column ["table_name" ] != primary_table ["table_name" ]):
32463297 raise SyncanySqlCompileException ('error set expression, the assigned item must be a table field, related sql "%s"' % self .to_sql (expression ))
32473298 if self .is_column (set_expression .args ["expression" ], config , arguments ):
3248- value_column = self .parse_column (set_expression .args ["expression" ], config , arguments , primary_table )
3299+ value_column = self .parse_column (set_expression .args ["expression" ], config , arguments , None )
32493300 if ((not value_column ["table_name" ] or column ["table_name" ] == value_column ["table_name" ])
32503301 and column ["column_name" ] == value_column ["column_name" ]):
32513302 primary_keys .append (column ["column_name" ])
32523303 continue
32533304 set_expressions .append ({"column" : column , "expression" : set_expression .args ["expression" ]})
32543305
3306+ if not primary_table :
3307+ raise SyncanySqlCompileException ('unknown primary table, related sql "%s"' % self .to_sql (expression ))
32553308 if not primary_keys and not primary_table .get ("primary_keys" ):
32563309 raise SyncanySqlCompileException ('unknown primary key, related sql "%s"' % self .to_sql (expression ))
32573310 sql = ["INSERT INTO" ]
@@ -3286,7 +3339,14 @@ def optimize_rewrite_update(self, expression, config, arguments):
32863339 sql .append ("SELECT" )
32873340 sql .append (", " .join (select_sql ))
32883341 sql .append ("FROM" )
3289- sql .append (self .generate_sql (expression .args ["this" ]))
3342+ if expression .args ["this" ].args .get ("table_expressions" ):
3343+ sql .append (", " .join ([self .generate_sql (expression .args ["this" ])] +
3344+ [self .generate_sql (table_expression ) for table_expression in expression .args ["this" ].args ["table_expressions" ]]))
3345+ else :
3346+ sql .append (self .generate_sql (expression .args ["this" ]))
3347+ if expression .args ["this" ].args .get ("join_expressions" ):
3348+ for join_expression in expression .args ["this" ].args .get ("join_expressions" ):
3349+ sql .append (self .generate_sql (join_expression ))
32903350 if expression .args .get ("where" ):
32913351 sql .append (self .generate_sql (expression .args ["where" ]))
32923352 return maybe_parse (" " .join (sql ), dialect = CompilerDialect )
0 commit comments