from typing import Any, Optional, Union
from ..expressions import ComplexCondition, Condition, Raw
from .base import Query
from .mixins import WhereClauseMixin
[docs]
class UpdateQuery(WhereClauseMixin, Query):
__slots__ = (
"_table",
"_values",
"_wheres",
"_limit",
"_order_bys",
"_joins",
"_dialect",
"_allow_all_rows",
)
def __init__(self, table: str, dialect=None, debug=False):
super().__init__(dialect, debug)
self._table = table
self._values: dict[str, Any] = {}
self._wheres: list[tuple[str, str, Any]] = []
self._limit: Optional[int] = None
self._order_bys: list[str] = []
self._joins: list[tuple[str, str, Optional[str]]] = [] # (type, table, on)
self._allow_all_rows: bool = False
[docs]
def set(self, values: dict[str, Any]) -> "UpdateQuery":
self._values.update(values)
return self
[docs]
def join(
self, table: str, on: Optional[str] = None, join_type: str = "INNER"
) -> "UpdateQuery":
"""Add a JOIN clause (MySQL multi-table UPDATE)."""
self._joins.append((join_type, table, on))
return self
[docs]
def left_join(self, table: str, on: Optional[str] = None) -> "UpdateQuery":
"""Add a LEFT JOIN clause."""
return self.join(table, on, join_type="LEFT")
[docs]
def where(
self,
column: Union[str, Raw, Condition, ComplexCondition],
value: Any = None,
operator: str = "=",
) -> "UpdateQuery":
connector, sql, params = self._build_where_clause(column, value, operator)
self._wheres.append((connector, sql, params))
return self
[docs]
def limit(self, limit: int) -> "UpdateQuery":
self._limit = limit
return self
[docs]
def order_by(self, *columns: str) -> "UpdateQuery":
for col in columns:
direction = "ASC"
if col.startswith("-"):
direction = "DESC"
col = col[1:]
self._order_bys.append(f"{self._dialect.quote(col)} {direction}")
return self
[docs]
def allow_all_rows(self) -> "UpdateQuery":
"""Allow UPDATE without WHERE clause (updates all rows).
This is a safety feature to prevent accidental mass updates.
You must call this method explicitly if you want to update all rows.
Returns:
Self for method chaining
Example:
>>> Q.update("users").set({"active": False}).allow_all_rows().build()
"""
self._allow_all_rows = True
return self
[docs]
def batch_update(self, values: list[dict[str, Any]], key: str) -> "UpdateQuery":
"""
Perform a batch update using CASE WHEN.
Args:
values: List of dictionaries containing values to update.
Each dictionary must contain the key column.
key: The column name to use as the key (e.g., "id").
"""
if not values:
return self
# Validate key exists in all rows
first_keys = values[0].keys()
if key not in first_keys:
raise ValueError(f"Key '{key}' not found in values")
# Collect all IDs
ids = [row[key] for row in values]
# Group values by column
columns_to_update = [k for k in first_keys if k != key]
for col in columns_to_update:
# Build CASE WHEN
case_parts = [f"CASE {self._dialect.quote(key)}"]
case_params = []
for row in values:
case_parts.append(f"WHEN {self._ph} THEN {self._ph}")
case_params.extend([row[key], row[col]])
case_parts.append("END")
# Set column to Raw SQL
self.set({col: Raw(" ".join(case_parts), case_params)})
# Add WHERE IN clause
self.where_in(key, ids)
return self
[docs]
def build(self) -> tuple[str, tuple[Any, ...]]:
if not self._table:
raise ValueError("No table specified")
if not self._values:
raise ValueError("No values to update")
# Safety check: UPDATE without WHERE
if not self._wheres and not self._allow_all_rows:
raise ValueError(
"UPDATE without WHERE clause would affect all rows. "
"If this is intentional, call .allow_all_rows() first."
)
parts: list[str] = []
params: list[Any] = []
ph = self._ph
# CTEs
self._build_ctes(parts, params)
# UPDATE table SET
parts.append("UPDATE ")
parts.append(self._dialect.quote(self._table))
# JOINs (for multi-table UPDATE)
if self._joins:
for type_, table, on in self._joins:
parts.append(f" {type_} JOIN {table}")
if on:
parts.append(f" ON {on}")
parts.append(" SET ")
first = True
for col, val in self._values.items():
if not first:
parts.append(", ")
first = False
parts.append(self._dialect.quote(col))
parts.append(" = ")
# Handle Raw expressions
if isinstance(val, Raw):
parts.append(val.sql)
params.extend(val.params)
# Handle subqueries
elif hasattr(val, "build"):
sub_sql, sub_params = val.build()
parts.append(f"({sub_sql})")
params.extend(sub_params)
# Handle regular values
else:
parts.append(ph)
params.append(val)
# WHERE
if self._wheres:
parts.append(" WHERE ")
for i, (connector, sql, p) in enumerate(self._wheres):
if i > 0:
parts.append(f" {connector} ")
parts.append(sql)
params.extend(p)
# ORDER BY
if self._order_bys:
parts.append(" ORDER BY ")
parts.append(", ".join(self._order_bys))
# LIMIT
if self._limit:
parts.append(f" LIMIT {self._limit}")
sql = "".join(parts)
params_tuple = tuple(params)
self._print_debug(sql, params_tuple)
return sql, params_tuple