#!/usr/bin/env python
# coding: utf-8
"""
Database Session 后端
提供基于 SQLite 数据库的 Session 实现
"""
from multiprocessing import current_process
import sqlite3
import time
import json
from typing import Any, Optional
from .session import Session
class DatabaseSession:
    """
    SQLite-backed session store.

    Sessions live in a single table with the columns ``session_id``
    (primary key), ``data`` (JSON text), ``created_at`` and ``expires_at``
    (Unix timestamps in whole seconds, from ``int(time.time())``).

    Expiry is enforced lazily. An expired row is removed when it is read
    through :meth:`get` or :meth:`ttl`. :meth:`exists` reports it as absent
    but leaves it in the table, and :meth:`__len__` still counts it until
    it is removed.

    The store can be used as a context manager; leaving the ``with`` block
    closes the underlying connection.
    """

    def __init__(
        self,
        db_path: str = ":memory:",
        table_name: str = "sessions",
        expiration_time: int = 3600,
        **kwargs: Any,
    ):
        """
        Initialize the database session store.

        Args:
            db_path: Path of the SQLite database file; defaults to an
                in-memory database.
            table_name: Name of the session table. It is interpolated
                directly into SQL text, so it must be a trusted, valid
                SQLite identifier (never user input).
            expiration_time: Default session lifetime in seconds, applied
                by :meth:`put`.
            **kwargs: Extra configuration options; accepted for interface
                compatibility and currently ignored.
        """
        self._db_path = db_path
        self._table_name = table_name
        self._expiration_time = expiration_time
        self._conn: Optional[sqlite3.Connection] = None
        # Open the connection eagerly so a bad path / table name fails fast.
        self._create_table()

    def _connect(self) -> sqlite3.Connection:
        """
        Return the shared connection, opening it on first use.

        The session table is created on every newly opened connection. This
        matters for ``":memory:"`` databases: after :meth:`close`, the next
        call opens a brand-new, empty database that would otherwise have no
        table.
        """
        if self._conn is None:
            conn = sqlite3.connect(self._db_path)
            # Rows support access by column name, e.g. row["expires_at"].
            conn.row_factory = sqlite3.Row
            try:
                self._init_schema(conn)
            except BaseException:
                # Do not leak a half-initialized connection.
                conn.close()
                raise
            self._conn = conn
        return self._conn

    def _init_schema(self, conn: sqlite3.Connection) -> None:
        """Create the session table on ``conn`` if it does not exist yet."""
        with conn:
            conn.execute(f"""
                CREATE TABLE IF NOT EXISTS {self._table_name} (
                    session_id TEXT PRIMARY KEY,
                    data TEXT NOT NULL,
                    created_at INTEGER NOT NULL,
                    expires_at INTEGER NOT NULL
                )
            """)

    def _create_table(self) -> None:
        """Ensure the session table exists on the current connection."""
        self._init_schema(self._connect())

    def _make_key(self, key: str) -> str:
        """Return the storage key for ``key`` (currently the key unchanged)."""
        return key

    def _fetch(self, session_id: str) -> Optional[sqlite3.Row]:
        """Return the ``data``/``expires_at`` row for ``session_id``, or None."""
        return self._connect().execute(
            f"SELECT data, expires_at FROM {self._table_name} WHERE session_id = ?",
            (session_id,),
        ).fetchone()

    def put(self, session_id: str, session: "Session") -> None:
        """
        Store (insert or overwrite) a session.

        The session is converted with ``dict(session)`` and serialized as
        JSON. The row is replaced as a whole, so both ``created_at`` and
        ``expires_at`` are reset; the new expiry is now plus the store's
        default ``expiration_time``.

        Args:
            session_id: Session ID.
            session: Session object; must be convertible with ``dict()``
                and contain only JSON-serializable values.

        Raises:
            ValueError: If the session data cannot be serialized.
        """
        try:
            data = json.dumps(dict(session))
        except Exception as e:
            raise ValueError(f"无法序列化 Session 数据: {e}") from e
        created_at = int(time.time())
        expires_at = created_at + self._expiration_time
        conn = self._connect()
        # `with conn` commits on success and rolls back on error.
        with conn:
            conn.execute(
                f"""
                INSERT OR REPLACE INTO {self._table_name}
                (session_id, data, created_at, expires_at)
                VALUES (?, ?, ?, ?)
                """,
                (session_id, data, created_at, expires_at),
            )

    def get(self, session_id: str) -> Optional["Session"]:
        """
        Load a session.

        Expired rows are deleted. Rows whose data is not a JSON object are
        treated as corrupt and deleted as well.

        Args:
            session_id: Session ID.

        Returns:
            A new ``Session`` populated with the stored data, or None if the
            session does not exist, has expired, or is corrupt.
        """
        row = self._fetch(session_id)
        if row is None:
            return None
        if int(time.time()) > row["expires_at"]:
            self.delete(session_id)
            return None
        try:
            data = json.loads(row["data"])
        except (TypeError, ValueError):
            data = None
        if not isinstance(data, dict):
            # put() always stores a JSON object; anything else is corrupt.
            self.delete(session_id)
            return None
        # NOTE(review): create() builds Session(store=self) while this passes
        # the id positionally; confirm against Session's constructor.
        session = Session(session_id)
        session.update(data)
        return session

    def delete(self, session_id: str) -> None:
        """
        Delete a session; deleting a missing ID is a no-op.

        Args:
            session_id: Session ID.
        """
        conn = self._connect()
        with conn:
            conn.execute(
                f"DELETE FROM {self._table_name} WHERE session_id = ?",
                (session_id,),
            )

    def exists(self, session_id: str) -> bool:
        """
        Check whether a session exists and has not expired.

        Expired rows are not deleted by this check.

        Args:
            session_id: Session ID.

        Returns:
            True if the session exists and ``now <= expires_at``.
        """
        row = self._fetch(session_id)
        return row is not None and int(time.time()) <= row["expires_at"]

    def expire(self, session_id: str, expiration: int) -> bool:
        """
        Set a session's remaining lifetime.

        Args:
            session_id: Session ID.
            expiration: New lifetime in seconds, counted from now.

        Returns:
            True if a live (non-expired) session was updated, else False.
        """
        now = int(time.time())
        conn = self._connect()
        with conn:
            # A single conditional UPDATE replaces exists()+UPDATE; the
            # `expires_at >= now` condition mirrors exists().
            cursor = conn.execute(
                f"""
                UPDATE {self._table_name}
                SET expires_at = ?
                WHERE session_id = ? AND expires_at >= ?
                """,
                (now + expiration, session_id, now),
            )
        return cursor.rowcount > 0

    def ttl(self, session_id: str) -> int:
        """
        Return a session's remaining lifetime.

        An expired session is deleted by this call.

        Args:
            session_id: Session ID.

        Returns:
            Remaining seconds (0 during the session's final second, matching
            :meth:`exists`), -2 if the session does not exist, or -1 if it
            has expired.
        """
        row = self._fetch(session_id)
        if row is None:
            return -2
        remaining = row["expires_at"] - int(time.time())
        if remaining < 0:
            self.delete(session_id)
            return -1
        return remaining

    def clear(self) -> None:
        """Delete all sessions."""
        conn = self._connect()
        with conn:
            conn.execute(f"DELETE FROM {self._table_name}")

    def __len__(self) -> int:
        """
        Return the number of stored session rows.

        This includes expired sessions that have not been deleted yet.
        """
        return self._connect().execute(
            f"SELECT COUNT(*) FROM {self._table_name}"
        ).fetchone()[0]

    def close(self) -> None:
        """
        Close the database connection; safe to call more than once.

        The store remains usable: the next operation reopens the connection.
        For ``":memory:"`` databases, that is a new, empty database.
        """
        # Drop the handle first so it is cleared even if close() raises.
        conn, self._conn = self._conn, None
        if conn is not None:
            try:
                conn.close()
            except sqlite3.Error:
                # Best effort: the handle has already been discarded.
                pass

    def __enter__(self):
        """Enter the context manager and return the store itself."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Close the connection on exit; exceptions are not suppressed."""
        self.close()

    def create(self) -> "Session":
        """
        Create a new ``Session`` bound to this store.

        Returns:
            A new ``Session`` object.
        """
        session = Session(store=self)
        return session

    def save(self, session: "Session") -> None:
        """
        Persist ``session`` under its own ``id``.

        Args:
            session: Session object.
        """
        self.put(session.id, session)
# Names exported by `from <module> import *`.
__all__ = ["DatabaseSession"]