import operator
from typing import Any, Optional, Union

from ..expressions import ComplexCondition, Condition, Raw
from .base import Query
from .mixins import WhereClauseMixin
class UpdateQuery(WhereClauseMixin, Query):
    """Fluent builder for SQL ``UPDATE`` statements.

    Assigned values are bound as parameters using the dialect's placeholder;
    the target table and column names are quoted with ``self._dialect.quote``.
    :meth:`build` refuses to emit an UPDATE without a WHERE clause unless
    :meth:`allow_all_rows` has been called.
    """

    __slots__ = (
        "_table",
        "_values",
        "_wheres",
        "_limit",
        "_order_bys",
        "_joins",
        "_dialect",  # read throughout; presumably assigned by Query.__init__
        "_allow_all_rows",
    )

    def __init__(self, table: str, dialect=None):
        """Start an UPDATE against ``table``; ``dialect`` is handed to ``Query``."""
        super().__init__(dialect)
        self._table = table
        self._values: dict[str, Any] = {}
        # (connector, sql, params) triples produced by _build_where_clause.
        self._wheres: list[tuple[str, str, Any]] = []
        self._limit: Optional[int] = None
        # Already-rendered "<quoted col> ASC|DESC" fragments.
        self._order_bys: list[str] = []
        self._joins: list[tuple[str, str, Optional[str]]] = []  # (type, table, on)
        self._allow_all_rows: bool = False

    def set(self, values: dict[str, Any]) -> "UpdateQuery":
        """Merge column -> value assignments into the SET clause.

        Later calls overwrite earlier assignments for the same column. A value
        may be a ``Raw`` expression (inserted verbatim with its own params),
        an object with a ``build()`` method (rendered as a parenthesized
        subquery), or any other value (bound as a parameter).
        """
        self._values.update(values)
        return self

    def join(
        self, table: str, on: Optional[str] = None, join_type: str = "INNER"
    ) -> "UpdateQuery":
        """Add a JOIN clause (MySQL multi-table UPDATE).

        ``join_type``, ``table`` and ``on`` are inserted into the SQL verbatim
        (not quoted or parameterized), so they must never carry untrusted input.
        """
        self._joins.append((join_type, table, on))
        return self

    def left_join(self, table: str, on: Optional[str] = None) -> "UpdateQuery":
        """Add a LEFT JOIN clause (same verbatim-insertion caveat as :meth:`join`)."""
        return self.join(table, on, join_type="LEFT")

    def where(
        self,
        column: Union[str, Raw, Condition, ComplexCondition],
        value: Any = None,
        operator: str = "=",
    ) -> "UpdateQuery":
        """Add a WHERE condition.

        The accepted argument forms are defined by the mixin's
        ``_build_where_clause``, which returns ``(connector, sql, params)``.
        In :meth:`build` the connector is placed *before* each condition after
        the first; the first condition's connector is ignored.
        """
        connector, sql, params = self._build_where_clause(column, value, operator)
        self._wheres.append((connector, sql, params))
        return self

    def limit(self, limit: Optional[int]) -> "UpdateQuery":
        """Cap the number of rows updated with ``LIMIT n``.

        The number is interpolated directly into the SQL text, so it is
        validated here. It must be a non-negative integer; anything accepted
        by :func:`operator.index` works, including numpy integers. ``None``
        clears a previously set limit. ``0`` is a valid limit and is emitted
        as ``LIMIT 0``.

        Raises:
            TypeError: If ``limit`` is not an integer (e.g. a str or float).
            ValueError: If ``limit`` is negative.
        """
        if limit is None:
            self._limit = None
            return self
        count = operator.index(limit)
        if count < 0:
            raise ValueError(f"LIMIT must be non-negative, got {count}")
        self._limit = count
        return self

    def order_by(self, *columns: str) -> "UpdateQuery":
        """Append ORDER BY columns; prefix a name with ``-`` for DESC."""
        for col in columns:
            if col.startswith("-"):
                col, direction = col[1:], "DESC"
            else:
                direction = "ASC"
            self._order_bys.append(f"{self._dialect.quote(col)} {direction}")
        return self

    def allow_all_rows(self) -> "UpdateQuery":
        """Allow UPDATE without WHERE clause (updates all rows).

        This is a safety feature to prevent accidental mass updates.
        You must call this method explicitly if you want to update all rows.
        There is no way to switch the flag back off once set.

        Returns:
            Self for method chaining

        Example:
            >>> Q.update("users").set({"active": False}).allow_all_rows().build()
        """
        self._allow_all_rows = True
        return self

    def _build_set_clause(self) -> tuple[str, list[Any]]:
        """Render the comma-separated ``col = value`` list and its params in order."""
        ph = self._dialect.parameter_placeholder()
        assignments: list[str] = []
        params: list[Any] = []
        for col, val in self._values.items():
            quoted = self._dialect.quote(col)
            if isinstance(val, Raw):
                # Raw SQL fragment: inserted verbatim, carries its own params.
                rendered = val.sql
                params.extend(val.params)
            elif hasattr(val, "build"):
                # Query-like object: rendered as a parenthesized subquery.
                sub_sql, sub_params = val.build()
                rendered = f"({sub_sql})"
                params.extend(sub_params)
            else:
                rendered = ph
                params.append(val)
            assignments.append(f"{quoted} = {rendered}")
        return ", ".join(assignments), params

    def build(self) -> tuple[str, tuple[Any, ...]]:
        """Render the statement.

        Clause order is UPDATE, JOIN, SET, WHERE, ORDER BY, LIMIT.

        Returns:
            ``(sql, params)``. ``params`` are ordered to match the
            placeholders in ``sql``: SET params come first, then WHERE params.

        Raises:
            ValueError: If no table or no values were given, or if there is no
                WHERE clause and :meth:`allow_all_rows` was not called.
        """
        if not self._table:
            raise ValueError("No table specified")
        if not self._values:
            raise ValueError("No values to update")
        # Safety check: refuse an unfiltered UPDATE unless explicitly opted in.
        if not self._wheres and not self._allow_all_rows:
            raise ValueError(
                "UPDATE without WHERE clause would affect all rows. "
                "If this is intentional, call .allow_all_rows() first."
            )

        parts: list[str] = ["UPDATE ", self._dialect.quote(self._table)]
        params: list[Any] = []

        # JOINs (MySQL multi-table UPDATE). Table and ON text are verbatim.
        # NOTE(review): MySQL documents that ORDER BY / LIMIT cannot be used
        # with the multiple-table form; combining them is not checked here.
        for join_type, table, on in self._joins:
            parts.append(f" {join_type} JOIN {table}")
            if on:
                parts.append(f" ON {on}")

        set_sql, set_params = self._build_set_clause()
        parts.append(" SET ")
        parts.append(set_sql)
        params.extend(set_params)

        if self._wheres:
            parts.append(" WHERE ")
            for i, (connector, sql, where_params) in enumerate(self._wheres):
                if i > 0:
                    parts.append(f" {connector} ")
                parts.append(sql)
                params.extend(where_params)

        if self._order_bys:
            parts.append(" ORDER BY ")
            parts.append(", ".join(self._order_bys))

        # `is not None` (not truthiness) so an explicit LIMIT 0 is emitted
        # instead of silently dropped, which would update every matching row.
        if self._limit is not None:
            parts.append(f" LIMIT {self._limit}")

        return "".join(parts), tuple(params)