# litefs.core source code (源代码)

#!/usr/bin/env python
# coding: utf-8

import argparse
import logging
import sys
import socket
import time
from datetime import datetime
from posixpath import abspath as path_abspath
from typing import Optional

from watchdog.observers import Observer

from .cache import (
    CacheBackend,
    CacheFactory,
    CacheManager,
    FileEventHandler,
    MemoryCache,
    TreeCache,
)
from .config import Config, load_config
from .database import DatabaseManager
from .error_pages import ErrorPageRenderer
from .handlers import RequestHandler, WSGIRequestHandler
from .middleware import MiddlewareManager
from .routing import Router
from .server import (
    DEFAULT_BUFFER_SIZE,
    BufferedRWPair,
    HTTPServer,
    ProcessHTTPServer,
    SocketIO,
    mainloop,
)
from .utils import log_error, log_info, make_logger

from ._version import __version__
from .plugins import PluginManager, PluginLoader


def make_config(**kwargs):
    """Create a Config object.

    Per the documented contract of load_config, settings are layered from:
    built-in defaults, an optional configuration file, ``LITEFS_*``
    environment variables, and finally the keyword arguments given here.

    Args:
        **kwargs: Configuration values. The special key ``config_file`` is
            not treated as a setting; it is passed to load_config as the path
            of the configuration file (``None`` when absent).

    Returns:
        The Config object produced by load_config.
    """
    overrides = {key: value for key, value in kwargs.items() if key != 'config_file'}
    return load_config(config_file=kwargs.get('config_file'), **overrides)
def make_server(host, port, request_size=-1):
    """Create a non-blocking, listening TCP socket bound to ``(host, port)``.

    The socket has SO_REUSEADDR and TCP_NODELAY enabled.

    Args:
        host: Address to bind to.
        port: Port to bind to (0 lets the OS choose one).
        request_size: Listen backlog; -1 (the default) means 1024.

    Returns:
        The listening socket. The caller owns it and must close it.

    Raises:
        OSError: If the socket cannot be configured, bound or put into
            listening mode. The socket is closed before the error propagates,
            so a failed call does not leak a file descriptor.
    """
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        sock.bind((host, port))
        sock.listen(1024 if request_size == -1 else request_size)
        sock.setblocking(False)
        sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
    except BaseException:
        # Release the descriptor on any failure, then let the error surface.
        sock.close()
        raise
    return sock
def is_port_available(host, port):
    """Check whether ``(host, port)`` can currently be bound.

    The probe sets SO_REUSEADDR, matching make_server, so a port that can be
    bound with that option counts as available.

    Args:
        host: Host address.
        port: Port number.

    Returns:
        bool: True if binding succeeded, False if it raised OSError.
    """
    try:
        # The context manager closes the probe socket on both outcomes;
        # previously a failed bind leaked it.
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            sock.bind((host, port))
    except OSError:
        return False
    return True
class Litefs(object):
    """LiteFS application object.

    Holds the configuration, logger, session store, cache, middleware,
    router, database manager and plugin system. It exposes a WSGI entry
    point (wsgi()) and a built-in server runner (run()).
    """

    def __init__(self, **kwargs):
        """Build the application from configuration.

        Args:
            **kwargs: Configuration values forwarded to make_config
                (including the optional ``config_file`` path).
        """
        self.config = config = make_config(**kwargs)
        level = logging.DEBUG if config.debug else logging.INFO
        self.logger = make_logger(__name__, log=config.log, level=level)

        # Keep watchdog's DEBUG output out of the log.
        logging.getLogger('watchdog').setLevel(logging.INFO)
        logging.getLogger('watchdog.observers').setLevel(logging.INFO)

        self.host = config.host
        self.port = config.port
        self.server = None

        # Sessions come from the process-wide SessionManager so that session
        # data is not lost when Litefs instances are created and destroyed.
        from litefs.session.manager import SessionManager
        from litefs.session.factory import SessionBackend

        session_backend = getattr(config, 'session_backend', SessionBackend.MEMORY)
        session_config = {}
        if session_backend == SessionBackend.REDIS:
            session_config = {
                "host": getattr(config, "redis_host", "localhost"),
                "port": getattr(config, "redis_port", 6379),
                "db": getattr(config, "redis_db", 0),
                "password": getattr(config, "redis_password", None),
                "key_prefix": getattr(config, "redis_session_key_prefix", "litefs:session:"),
                "expiration_time": getattr(config, "session_expiration_time", 3600),
            }
        elif session_backend == SessionBackend.DATABASE:
            session_config = {
                "db_path": getattr(config, "database_path", ":memory:"),
                "table_name": getattr(config, "database_session_table", "sessions"),
                "expiration_time": getattr(config, "session_expiration_time", 3600),
            }
        elif session_backend == SessionBackend.MEMCACHE:
            session_config = {
                "servers": getattr(config, "memcache_servers", ["localhost:11211"]),
                "key_prefix": getattr(config, "memcache_session_key_prefix", "litefs:session:"),
                "expiration_time": getattr(config, "session_expiration_time", 3600),
            }
        elif session_backend == SessionBackend.MEMORY:
            session_config = {
                "max_size": getattr(config, "session_max_size", 1000000),
            }

        self.sessions = SessionManager.get_session_cache(
            backend=session_backend,
            cache_key='default',
            **session_config
        )
        self.caches = CacheManager.get_cache(
            backend=getattr(config, 'cache_backend', CacheBackend.TREE),
            cache_key='app_cache',
            max_size=getattr(config, 'cache_max_size', 10000),
            clean_period=getattr(config, 'cache_clean_period', 60),
            expiration_time=getattr(config, 'cache_expiration_time', 3600),
        )

        self.middleware_manager = MiddlewareManager()
        # Cache of middleware instances; reset to None whenever the middleware
        # list changes so _get_middleware_instances() rebuilds it.
        self._middleware_instances = []

        error_pages_dir = getattr(config, "error_pages_dir", None)
        self.error_page_renderer = ErrorPageRenderer(error_pages_dir)

        self.router = Router()

        self.db_manager = DatabaseManager()
        self.db = self.db_manager.get_database(self.config)

        self.plugin_manager = PluginManager(self)
        self.plugin_loader = PluginLoader(self)
        # Default plugin directories (relative to the current working directory).
        self.plugin_loader.add_plugin_dir('./plugins')
        self.plugin_loader.add_plugin_dir('./litefs/plugins')

    # ------------------------------------------------------------------
    # Database
    # ------------------------------------------------------------------

    def database(self, name: str = 'default'):
        """Return the database instance called ``name``."""
        return self.db_manager.get_database(self.config, name)

    def db_session(self, name: str = 'default'):
        """Return a database session for the database called ``name``."""
        return self.db_manager.get_session(name)

    def session(self, name: str = 'default'):
        """Deprecated alias of db_session().

        Emits a DeprecationWarning attributed to the caller's line.
        """
        import warnings
        # stacklevel=2 makes the warning point at the code calling session(),
        # not at this line.
        warnings.warn('session() method is deprecated, use db_session() instead',
                      DeprecationWarning, stacklevel=2)
        return self.db_session(name)

    def create_all_tables(self, name: str = 'default'):
        """Create all tables in the database called ``name``."""
        self.db_manager.create_all(name)

    def drop_all_tables(self, name: str = 'default'):
        """Drop all tables in the database called ``name``."""
        self.db_manager.drop_all(name)

    # ------------------------------------------------------------------
    # WSGI
    # ------------------------------------------------------------------

    @staticmethod
    def _to_bytes(value):
        """Encode one response chunk: str as UTF-8, bytes unchanged, anything else via str()."""
        if isinstance(value, bytes):
            return value
        if isinstance(value, str):
            return value.encode("utf-8")
        return str(value).encode("utf-8")

    def wsgi(self):
        """Return a PEP 3333 WSGI application callable for this app.

        Usage::

            import litefs
            app = litefs.Litefs()
            application = app.wsgi()

        With gunicorn::

            gunicorn -w 4 -b :8000 wsgi_example:application

        With uWSGI::

            uwsgi --http :8000 --wsgi-file wsgi_example.py
        """
        import json
        from collections.abc import Iterable

        def application(environ, start_response):
            """WSGI callable: run middleware and the handler, return an iterable of bytes."""
            # Bound before ``try`` so the error path can tell whether the
            # handler object was ever created.
            request_handler = None
            try:
                request_handler = WSGIRequestHandler(self, environ)

                # A non-None result from request middleware short-circuits the handler.
                middleware_result = self.middleware_manager.process_request(request_handler)
                if middleware_result is not None:
                    if isinstance(middleware_result, (list, tuple)) and len(middleware_result) == 3:
                        status, headers, content = middleware_result
                    else:
                        content = middleware_result
                        status, headers = "200 OK", [("Content-Type", "text/plain; charset=utf-8")]
                    start_response(status, headers)
                    if isinstance(content, str):
                        return [content.encode("utf-8")]
                    if isinstance(content, bytes):
                        return [content]
                    if isinstance(content, dict):
                        return [json.dumps(content, ensure_ascii=False).encode("utf-8")]
                    if isinstance(content, (list, tuple)):
                        return [json.dumps(content, ensure_ascii=False, default=str).encode("utf-8")]
                    if content is None:
                        return [b""]
                    return [str(content).encode("utf-8")]

                handler_result = request_handler.handler()
                if (
                    isinstance(handler_result, (list, tuple))
                    and len(handler_result) == 3
                    and isinstance(handler_result[0], str)
                    and isinstance(handler_result[1], list)
                ):
                    status, headers, content = handler_result
                else:
                    content = handler_result
                    status, headers = "200 OK", [("Content-Type", "text/plain; charset=utf-8")]
                start_response(status, headers)

                # HTTP header names are case-insensitive; like dict(headers),
                # the last Content-Type wins.
                content_type = ""
                for header_name, header_value in headers:
                    if header_name.lower() == "content-type":
                        content_type = header_value
                is_json = "application/json" in content_type

                if (not isinstance(content, (str, bytes, dict, list, tuple, type(None)))
                        and isinstance(content, Iterable)):
                    body = content

                    def content_generator():
                        try:
                            for item in body:
                                yield self._to_bytes(item)
                        finally:
                            # PEP 3333: the server calls close() on the returned
                            # iterable; forward it to the wrapped body.
                            close = getattr(body, "close", None)
                            if callable(close):
                                close()

                    return content_generator()
                if isinstance(content, dict):
                    return [json.dumps(content, ensure_ascii=False).encode("utf-8")]
                if isinstance(content, (list, tuple)):
                    if is_json:
                        return [json.dumps(content, ensure_ascii=False, default=str).encode("utf-8")]
                    return [self._to_bytes(item) for item in content]
                if isinstance(content, str):
                    if is_json:
                        # NOTE(review): a str is JSON-encoded (quoted) under an
                        # application/json Content-Type; pre-serialized JSON
                        # should be returned as bytes to avoid double encoding.
                        return [json.dumps(content, ensure_ascii=False).encode("utf-8")]
                    return [content.encode("utf-8")]
                if isinstance(content, bytes):
                    return [content]
                return [str(content).encode("utf-8")]
            except Exception as e:
                # start_response may already have been called above; PEP 3333
                # requires exc_info on any second call from an error handler.
                exc_info = sys.exc_info()

                middleware_result = None
                if request_handler is not None:
                    middleware_result = self.middleware_manager.process_exception(request_handler, e)
                if middleware_result is not None:
                    if isinstance(middleware_result, (list, tuple)) and len(middleware_result) == 3:
                        status, headers, content = middleware_result
                    else:
                        content = middleware_result
                        status, headers = "200 OK", [("Content-Type", "text/plain; charset=utf-8")]
                    start_response(status, headers, exc_info)
                    return [self._to_bytes(content)]

                from .exceptions import HttpError
                if isinstance(e, HttpError):
                    status_code = e.args[0] if len(e.args) > 0 else 500
                    # str() so non-string messages can still be encoded.
                    message = str(e.args[1]) if len(e.args) > 1 else str(e)
                    start_response(f"{status_code} {message}",
                                   [("Content-Type", "text/plain; charset=utf-8")], exc_info)
                    return [message.encode("utf-8")]

                log_error(self.logger, str(e))
                start_response("500 Internal Server Error",
                               [("Content-Type", "text/plain; charset=utf-8")], exc_info)
                return [b"500 Internal Server Error"]

        return application

    # ------------------------------------------------------------------
    # Middleware
    # ------------------------------------------------------------------

    def add_middleware(self, middleware_class, **kwargs):
        """Add a middleware class.

        Args:
            middleware_class: Middleware class.
            **kwargs: Arguments passed to the middleware constructor.

        Returns:
            self, for chaining.
        """
        self.middleware_manager.add(middleware_class, **kwargs)
        self._middleware_instances = None
        return self

    def remove_middleware(self, middleware_class):
        """Remove a middleware class."""
        self.middleware_manager.remove(middleware_class)
        self._middleware_instances = None

    def clear_middleware(self):
        """Remove all middleware."""
        self.middleware_manager.clear()
        self._middleware_instances = None

    # ------------------------------------------------------------------
    # Routing
    # ------------------------------------------------------------------

    def _route_or_decorator(self, register, path, handler, name):
        """Register a route now, or return a decorator that registers it.

        ``register`` is called as ``register(path, handler, name)``. If a
        handler is given, it is registered immediately and ``self`` is
        returned for chaining. Otherwise a decorator is returned that
        registers the decorated function and returns it unchanged.
        """
        if handler is None:
            def decorator(func):
                register(path, func, name)
                return func
            return decorator
        register(path, handler, name)
        return self

    def add_route(self, path, methods=None, handler=None, name=None):
        """Add a route; usable directly or as a decorator.

        Args:
            path: Route path.
            methods: HTTP methods; defaults to ['GET'].
            handler: Handler function; omit to use as a decorator.
            name: Route name.
        """
        if methods is None:
            methods = ['GET']
        return self._route_or_decorator(
            lambda p, h, n: self.router.add_route(p, methods, h, n), path, handler, name)

    def add_get(self, path, handler=None, name=None):
        """Add a GET route; usable directly or as a decorator."""
        return self._route_or_decorator(self.router.add_get, path, handler, name)

    def add_post(self, path, handler=None, name=None):
        """Add a POST route; usable directly or as a decorator."""
        return self._route_or_decorator(self.router.add_post, path, handler, name)

    def add_put(self, path, handler=None, name=None):
        """Add a PUT route; usable directly or as a decorator."""
        return self._route_or_decorator(self.router.add_put, path, handler, name)

    def add_delete(self, path, handler=None, name=None):
        """Add a DELETE route; usable directly or as a decorator."""
        return self._route_or_decorator(self.router.add_delete, path, handler, name)

    def add_patch(self, path, handler=None, name=None):
        """Add a PATCH route; usable directly or as a decorator."""
        return self._route_or_decorator(self.router.add_patch, path, handler, name)

    def add_options(self, path, handler=None, name=None):
        """Add an OPTIONS route; usable directly or as a decorator."""
        return self._route_or_decorator(self.router.add_options, path, handler, name)

    def add_head(self, path, handler=None, name=None):
        """Add a HEAD route; usable directly or as a decorator."""
        return self._route_or_decorator(self.router.add_head, path, handler, name)

    def add_static(self, prefix: str, directory: str, name: Optional[str] = None):
        """Serve files from ``directory`` under the URL ``prefix`` (e.g. '/static')."""
        self.router.add_static(prefix, directory, name)
        return self

    def register_routes(self, module):
        """Register routes declared on callables of a module.

        Every callable attribute with an iterable ``_routes`` attribute is
        registered once per entry. Each entry is a mapping with 'path' and
        'methods' keys and an optional 'name' key.

        Args:
            module: Module object, or a module name to import.

        Returns:
            self, for chaining.
        """
        import importlib
        from collections.abc import Mapping

        if isinstance(module, str):
            module = importlib.import_module(module)

        for attr_name in dir(module):
            obj = getattr(module, attr_name)
            routes = getattr(obj, '_routes', None)
            if routes is None or not callable(obj):
                continue
            try:
                route_infos = list(routes)
            except TypeError:
                # _routes exists but is not iterable: not a route declaration.
                continue
            # Only iterating _routes is guarded, so errors raised by
            # add_route itself are no longer silently swallowed.
            for route_info in route_infos:
                if not isinstance(route_info, Mapping):
                    continue
                self.add_route(
                    path=route_info['path'],
                    methods=route_info['methods'],
                    handler=obj,
                    name=route_info.get('name'),
                )
        return self

    def url_for(self, name, **kwargs):
        """Build the URL of the route called ``name`` from ``kwargs``."""
        return self.router.url_for(name, **kwargs)

    # ------------------------------------------------------------------
    # Plugins
    # ------------------------------------------------------------------

    def register_plugin(self, plugin_class):
        """Register a plugin class; returns self for chaining."""
        self.plugin_manager.register(plugin_class)
        return self

    def load_plugins(self):
        """Register plugins discovered on disk, then load every registered plugin."""
        for plugin_class in self.plugin_loader.load_plugins().values():
            self.plugin_manager.register(plugin_class)
        self.plugin_manager.load_all()
        return self

    def get_plugin(self, plugin_name: str):
        """Return the plugin instance called ``plugin_name``, or None."""
        return self.plugin_manager.get_plugin(plugin_name)

    def has_plugin(self, plugin_name: str) -> bool:
        """Return whether the plugin called ``plugin_name`` is loaded."""
        return self.plugin_manager.has_plugin(plugin_name)

    def get_all_plugins(self):
        """Return all loaded plugin instances."""
        return self.plugin_manager.get_all_plugins()

    # ------------------------------------------------------------------
    # Health checks
    # ------------------------------------------------------------------

    def add_health_check(self, name: str, check_func):
        """Add a health check to the first HealthCheck middleware.

        ``check_func`` returns True when healthy. This does nothing if no
        HealthCheck middleware has been added.
        """
        from .middleware import HealthCheck
        for middleware in self._get_middleware_instances():
            if isinstance(middleware, HealthCheck):
                middleware.add_check(name, check_func)
                break

    def add_ready_check(self, name: str, check_func):
        """Add a readiness check to the first HealthCheck middleware.

        ``check_func`` returns True when ready. This does nothing if no
        HealthCheck middleware has been added.
        """
        from .middleware import HealthCheck
        for middleware in self._get_middleware_instances():
            if isinstance(middleware, HealthCheck):
                middleware.add_ready_check(name, check_func)
                break

    def _get_middleware_instances(self):
        """Return the middleware instances, building and caching them on first use."""
        if self._middleware_instances is None:
            self._middleware_instances = self.middleware_manager.get_middleware_instances(self)
        return self._middleware_instances

    # ------------------------------------------------------------------
    # Built-in server
    # ------------------------------------------------------------------

    def handler(self, request, rw, environ, server):
        """Handle one request for the built-in server and finish the response."""
        request_handler = RequestHandler(self, rw, environ, request)
        return request_handler.finish(request_handler.handler())

    @staticmethod
    def _snapshot_py_mtimes(project_dir):
        """Return ``{path: mtime}`` for every ``.py`` file under ``project_dir``."""
        import os
        mtimes = {}
        for root, _dirs, files in os.walk(project_dir):
            for filename in files:
                if filename.endswith('.py'):
                    file_path = os.path.join(root, filename)
                    try:
                        mtimes[file_path] = os.path.getmtime(file_path)
                    except OSError:
                        # The file vanished or is unreadable between walk and stat.
                        pass
        return mtimes

    @staticmethod
    def _py_files_changed(snapshot):
        """Return True if any file in ``snapshot`` was modified or deleted."""
        import os
        for file_path, old_mtime in snapshot.items():
            try:
                if os.path.getmtime(file_path) != old_mtime:
                    return True
            except FileNotFoundError:
                return True
            except OSError:
                continue
        return False

    @staticmethod
    def _terminate_child(proc, timeout=15):
        """Best-effort stop of a child process: SIGTERM, wait, then SIGKILL if needed."""
        import subprocess
        try:
            proc.terminate()
            try:
                proc.wait(timeout=timeout)
            except subprocess.TimeoutExpired:
                proc.kill()
                proc.wait()
        except Exception:
            # Deliberately best-effort: the process may already be gone.
            pass

    def run(self, poll_interval=0.2, processes=1, no_reload=False):
        """Start the HTTP server, optionally under a simple auto-reloader.

        Supervisor mode (default): unless ``no_reload`` is set, this process
        does not serve requests itself. It re-executes the program
        (``sys.executable`` + ``sys.argv``) as a child with
        ``LITEFS_CHILD_PROCESS=1``. It restarts that child whenever a ``.py``
        file under the main script's directory is modified or deleted; after
        a restart it waits up to 30 s for the port to be released.

        Serving mode: the child process (or this process when ``no_reload``
        is set) loads plugins and serves requests.

        Args:
            poll_interval: Poll interval passed to the server loop.
            processes: Worker process count; >1 uses ProcessHTTPServer.
            no_reload: Serve directly in this process without the reloader.
        """
        import os
        import signal
        import subprocess

        main_file = getattr(sys.modules['__main__'], '__file__', None)
        if main_file:
            main_file = os.path.abspath(main_file)

        is_child_process = os.environ.get('LITEFS_CHILD_PROCESS', '0') == '1'

        # Current child in supervisor mode; read by parent_signal_handler.
        child_proc = None

        def parent_signal_handler(signum, frame):
            """SIGINT/SIGTERM in the supervisor: stop the child, then exit."""
            if child_proc:
                self._terminate_child(child_proc)
            sys.exit(0)

        def child_signal_handler(signum, frame):
            """SIGINT/SIGTERM in the serving process: shut the server down, then exit."""
            server = getattr(self, 'server', None)
            if server:
                try:
                    if hasattr(server, 'shutdown'):
                        server.shutdown()
                except Exception:
                    pass
            sys.exit(0)

        if no_reload or is_child_process:
            # Mark this process (and anything it spawns) as the serving side.
            os.environ['LITEFS_CHILD_PROCESS'] = '1'
            signal.signal(signal.SIGINT, child_signal_handler)
            signal.signal(signal.SIGTERM, child_signal_handler)

            self.load_plugins()

            log_info(self.logger, "Starting server on %s:%d (processes=%d)" % (self.host, self.port, processes))
            try:
                if processes > 1:
                    self.server = ProcessHTTPServer((self.host, self.port), self.handler, processes=processes)
                    self.server.max_request_size = self.config.max_request_size
                    # NOTE(review): kept as `server_forever`; confirm ProcessHTTPServer
                    # defines this name (not `serve_forever`).
                    self.server.server_forever(poll_interval=poll_interval)
                else:
                    self.server = HTTPServer((self.host, self.port), self.handler)
                    self.server.max_request_size = self.config.max_request_size
                    self.server.start()
                    mainloop(poll_interval=poll_interval)
            except KeyboardInterrupt:
                log_info(self.logger, "Server stopped by user")
            except SystemExit:
                # Raised by child_signal_handler via sys.exit().
                log_info(self.logger, "Server stopped by signal")
            except Exception as e:
                log_error(self.logger, "Server error: %s" % str(e))
            finally:
                server = getattr(self, 'server', None)
                if server:
                    try:
                        # Stop worker processes first, give them time to exit,
                        # then close the listening socket.
                        if hasattr(server, 'shutdown'):
                            server.shutdown()
                        time.sleep(2)
                        server.server_close()
                    except Exception:
                        pass
                try:
                    DatabaseManager.close_all()
                except Exception:
                    pass
            return

        # Supervisor: spawn the serving child and restart it on source changes.
        signal.signal(signal.SIGINT, parent_signal_handler)
        signal.signal(signal.SIGTERM, parent_signal_handler)

        while True:
            # Reset every cycle: a stale True from an earlier reload would make
            # a child that exits on its own be respawned in a tight loop.
            should_restart = False
            try:
                env = os.environ.copy()
                env['LITEFS_CHILD_PROCESS'] = '1'
                child_proc = subprocess.Popen([sys.executable] + sys.argv, env=env, close_fds=True)

                if main_file:
                    snapshot = self._snapshot_py_mtimes(os.path.dirname(main_file))
                    while child_proc.poll() is None:
                        if self._py_files_changed(snapshot):
                            should_restart = True
                            self._terminate_child(child_proc)
                            break
                        time.sleep(1)
                else:
                    # No script file to watch: just wait for the child to exit.
                    child_proc.wait()

                if not should_restart:
                    break

                # Wait (up to 30 s) until the old server has released the port.
                deadline = time.time() + 30
                while time.time() < deadline:
                    if is_port_available(self.host, self.port):
                        break
                    time.sleep(1)
            except Exception as e:
                log_error(self.logger, "Reloader error: %s" % e)
                time.sleep(1)
            finally:
                # Never leave a child running when leaving this iteration.
                if child_proc and child_proc.poll() is None:
                    self._terminate_child(child_proc)
def _cmd_args(args): title = args[0] if args else "litefs" parser = argparse.ArgumentParser(title, description=__doc__) parser.add_argument( "-H", "--host", dest="host", required=False, default="localhost", help="bind server to HOST" ) parser.add_argument( "-P", "--port", action="store", dest="port", type=int, required=False, default=9090, help="bind server to PORT", ) parser.add_argument( "--debug", action="store_true", dest="debug", required=False, default=False, help="start server in debug mode", ) parser.add_argument( "--log", dest="log", required=False, default="./default.log", help="save log to LOG" ) parser.add_argument( "--listen", dest="listen", type=int, required=False, default=1024, help="server LISTEN" ) parser.add_argument( "--max-request-size", dest="max_request_size", required=False, type=int, default=10485760, help="maximum request body size in bytes (default: 10MB)", ) parser.add_argument( "--max-upload-size", dest="max_upload_size", required=False, type=int, default=52428800, help="maximum file upload size in bytes (default: 50MB)", ) parser.add_argument( "--config", dest="config_file", required=False, default=None, help="path to configuration file (YAML, JSON, or TOML)", ) parser.add_argument( "--processes", dest="processes", type=int, required=False, default=1, help="number of worker processes (default: 1)", ) args = parser.parse_args(args and args[1:]) return args
def test_server():
    """Run a development server configured from the command line.

    Parses ``sys.argv`` with _cmd_args and builds a Litefs app with the
    logging, security, CORS and CSRF middleware. It then runs the app with
    the requested number of worker processes.
    """
    # CSRFMiddleware is added below, so it must be imported alongside the others
    # (without it the function failed with NameError).
    from litefs.middleware import (
        CORSMiddleware,
        CSRFMiddleware,
        LoggingMiddleware,
        SecurityMiddleware,
    )

    kwargs = vars(_cmd_args(sys.argv))
    # 'processes' is a run() option, not a config setting.
    processes = kwargs.pop('processes', 1)
    # Named `app` rather than `litefs` to avoid shadowing the package name.
    app = (
        Litefs(**kwargs)
        .add_middleware(LoggingMiddleware)
        .add_middleware(SecurityMiddleware)
        .add_middleware(CORSMiddleware)
        .add_middleware(CSRFMiddleware)
    )
    app.run(poll_interval=0.1, processes=processes)
# Script entry point: run the development server when executed directly.
if __name__ == "__main__":
    test_server()