#!/usr/bin/env python
# coding: utf-8
from errno import EAGAIN, EMFILE, ENOTCONN, EPIPE, EWOULDBLOCK
from functools import lru_cache, partial
from io import DEFAULT_BUFFER_SIZE, BufferedRWPair, RawIOBase
from posixpath import abspath as path_abspath
from posixpath import join as path_join
from time import time
from traceback import print_exc
from urllib.parse import unquote_plus
from email.message import Message
try:
from greenlet import GreenletExit, getcurrent, greenlet
HAS_GREENLET = True
except ImportError:
HAS_GREENLET = False
greenlet = None
getcurrent = None
GreenletExit = None
try:
from select import EPOLLERR, EPOLLET, EPOLLHUP, EPOLLIN, EPOLLOUT
from select import epoll as select_epoll
HAS_EPOLL = True
except (ImportError, AttributeError):
HAS_EPOLL = False
EPOLLIN = EPOLLOUT = EPOLLHUP = EPOLLERR = EPOLLET = 0
select_epoll = None
from ..exceptions import HttpError
from ..handlers import RequestHandler, parse_form
from ..utils import log_error
import array
import traceback
import logging
import socket
import sys
import os
import multiprocessing
import time
# errno values from a non-blocking socket call meaning "not ready yet, try
# again later" rather than a real failure.
should_retry_error = (
    EWOULDBLOCK,
    EAGAIN,
)
def make_environ(server, rw, client_address):
    """Read the request line and headers from *rw* and build an environ dict.

    Args:
        server: the server instance; ``server_name`` and ``server_port`` are
            copied into the environ, and ``max_request_size`` (when present
            and not None) bounds the declared Content-Length.
        rw: buffered reader positioned at the start of an HTTP request.
        client_address: ``(host, port)`` pair of the peer.

    Returns:
        dict: server/request variables (``REQUEST_METHOD``, ``PATH_INFO``,
        ``SCRIPT_NAME``, ``QUERY_STRING``, ``CONTENT_TYPE``, optional
        ``CONTENT_LENGTH``, ``CHARSET``, ...) plus one ``HTTP_<NAME>`` entry
        per remaining header.

    Raises:
        HttpError: "invalid http headers" on an empty request; 400 for an
            undecodable or malformed request line, a path without "/", or an
            invalid Content-Length; 413 when Content-Length exceeds
            ``server.max_request_size``.
    """
    environ = dict()
    environ["SERVER_NAME"] = server.server_name
    environ["SERVER_SOFTWARE"] = "litefs/0.4.0"
    environ["SERVER_PORT"] = server.server_port
    environ["REMOTE_ADDR"] = client_address[0]
    environ["REMOTE_HOST"] = client_address[0]
    environ["REMOTE_PORT"] = client_address[1]

    line = rw.readline(DEFAULT_BUFFER_SIZE)
    try:
        line = line.decode("utf-8")
    except UnicodeDecodeError:
        raise HttpError(400, "Bad Request") from None
    if not line:
        raise HttpError("invalid http headers")

    # A request line must be exactly "METHOD TARGET PROTOCOL"; anything else
    # used to escape as a bare ValueError and left the client without a reply.
    parts = line.split()
    if len(parts) != 3:
        raise HttpError(400, "Bad Request")
    request_method, path_info, protocol = parts

    path_info, _, query_string = path_info.partition("?")
    path_info = unquote_plus(path_info)
    if "/" not in path_info:
        raise HttpError(400, "Bad Request")
    # Everything after the first "/" is the script name; "/" maps to index.html.
    script_name = path_info.split("/", 1)[1] or "index.html"

    environ["REQUEST_METHOD"] = request_method.upper()
    # NOTE(review): PEP 3333 expects QUERY_STRING to stay percent-encoded;
    # decoding here may break parsing of escaped "&"/"=" — confirm against
    # parse_form before changing.
    environ["QUERY_STRING"] = unquote_plus(query_string)
    environ["SERVER_PROTOCOL"] = protocol
    environ["SCRIPT_NAME"] = script_name
    environ["PATH_INFO"] = path_info

    headers = make_headers(rw)
    length = headers.get("content-length")
    content_type = headers.get("content-type")
    if content_type:
        environ["CONTENT_TYPE"] = content_type
    else:
        environ["CONTENT_TYPE"] = content_type = "text/plain; charset=utf-8"
    if length:
        try:
            length = int(length)
        except ValueError:
            raise HttpError(400, "Bad Request") from None
        if length < 0:
            raise HttpError(400, "Bad Request")
        environ["CONTENT_LENGTH"] = length
        max_size = getattr(server, "max_request_size", None)
        if max_size is not None and length > max_size:
            raise HttpError(413, "Request Entity Too Large")

    _, params = parse_header(content_type)
    environ["CHARSET"] = params.get("charset")

    # Remaining headers become HTTP_* keys; names already present
    # (CONTENT_TYPE, CONTENT_LENGTH) are not duplicated.
    for key, value in headers.items():
        key = key.replace("-", "_").upper()
        if key in environ:
            continue
        environ[f"HTTP_{key}"] = value
    return environ
class SocketIO(RawIOBase):
    """Raw I/O adapter that makes a non-blocking socket cooperative.

    ``readinto`` and ``write`` register the socket with the module-level
    ``epoll`` loop, switch to the parent greenlet, and perform the actual
    ``recv``/``send`` once the loop switches back into the calling greenlet.
    Intended to be wrapped by ``io.BufferedRWPair``.
    """

    # Greenlets currently waiting on this socket for reading / writing; they
    # decide whether the epoll interest is registered, modified or removed.
    read_gr = write_gr = None

    def __init__(self, server, sock):
        RawIOBase.__init__(self)
        self._fileno = sock.fileno()
        self._sock = sock
        self._server = server

    def fileno(self):
        """Return the wrapped socket's file descriptor (cached at construction)."""
        return self._fileno

    def readable(self):
        return True

    def writable(self):
        return True

    def readinto(self, b):
        """Cooperatively wait for readability, then ``recv`` into *b*.

        Returns the number of bytes stored in *b*; 0 means the peer closed the
        connection. On a retryable error (EAGAIN/EWOULDBLOCK) the greenlet
        waits for the next readiness event instead of returning 0, which the
        buffered layer would treat as end-of-file.
        """
        real_epoll = epoll._epoll
        fileno = self._fileno
        curr = getcurrent()
        self.read_gr = curr
        if self.write_gr is None:
            real_epoll.register(fileno, EPOLLIN | EPOLLET)
        else:
            real_epoll.modify(fileno, EPOLLIN | EPOLLOUT | EPOLLET)
        try:
            while True:
                # Yield to the event loop; it switches back on readiness.
                curr.parent.switch()
                try:
                    data = self._sock.recv(len(b))
                except socket.error as e:
                    if e.errno not in should_retry_error:
                        raise
                    continue
                break
        finally:
            self.read_gr = None
            if self.write_gr is None:
                real_epoll.unregister(fileno)
            else:
                real_epoll.modify(fileno, EPOLLOUT | EPOLLET)
        n = len(data)
        try:
            b[:n] = data
        except TypeError:
            # e.g. array.array targets reject bytes slices; write through a
            # byte-format view of the same buffer instead.
            memoryview(b).cast("B")[:n] = data
        return n

    def write(self, data):
        """Cooperatively wait for writability, then ``send`` *data*.

        Returns the number of bytes actually sent (possibly fewer than
        ``len(data)``). On a retryable error the greenlet waits for the next
        readiness event instead of returning None, which buffered writers
        interpret as "would block".
        """
        real_epoll = epoll._epoll
        fileno = self._fileno
        curr = getcurrent()
        self.write_gr = curr
        if self.read_gr is None:
            real_epoll.register(fileno, EPOLLOUT | EPOLLET)
        else:
            real_epoll.modify(fileno, EPOLLIN | EPOLLOUT | EPOLLET)
        try:
            while True:
                curr.parent.switch()
                try:
                    return self._sock.send(data)
                except socket.error as e:
                    if e.errno not in should_retry_error:
                        raise
        finally:
            self.write_gr = None
            if self.read_gr is None:
                real_epoll.unregister(fileno)
            else:
                real_epoll.modify(fileno, EPOLLIN | EPOLLET)

    def close(self):
        """Close the stream, shut the socket down and close it; idempotent."""
        if self.closed:
            return
        RawIOBase.close(self)
        try:
            try:
                self._sock.shutdown(socket.SHUT_RDWR)
            except socket.error as e:
                # An already-disconnected peer is fine; anything else is real.
                if e.errno != ENOTCONN:
                    raise
        finally:
            try:
                self._sock.close()
            except OSError:
                pass
class Epoll(object):
    """Edge-triggered epoll dispatcher for listening servers and connections.

    Listening servers are registered via :meth:`register`; per-connection
    greenlets are stored in ``_greenlets`` (keyed by fd) by the server code
    and resumed here when their socket becomes ready. ``_idles`` holds
    ``(wake_up_timestamp, greenlet)`` pairs that are resumed once the
    timestamp has passed.
    """

    def __init__(self):
        self._epoll = select_epoll()
        self._servers = {}    # fileno -> server object exposing handle_request()
        self._greenlets = {}  # fileno -> greenlet serving that connection
        self._idles = []      # (wake-up timestamp, greenlet) pairs

    def register(self, server_socket):
        """Watch a listening server for incoming connections (edge-triggered)."""
        fileno = server_socket.fileno()
        self._servers[fileno] = server_socket
        self._epoll.register(fileno, EPOLLIN | EPOLLET)

    def unregister(self, server_socket):
        """Stop watching *server_socket*; unknown servers are ignored."""
        servers = self._servers
        fileno = server_socket.fileno()
        if fileno in servers:
            self._epoll.unregister(fileno)
            del servers[fileno]

    def close(self):
        """Unregister and close every server, then close the epoll object.

        Safe to call more than once: the server table is cleared so a second
        call does not touch the already-closed epoll descriptor.
        """
        for fileno, server_socket in self._servers.items():
            self._epoll.unregister(fileno)
            server_socket.server_close()
        self._servers.clear()
        self._epoll.close()

    def poll(self, poll_interval=0.2):
        """Run the event loop forever.

        Args:
            poll_interval: timeout in seconds for each epoll wait.

        Exceptions raised by handlers are printed and the loop continues,
        except EMFILE from ``handle_request`` and KeyboardInterrupt, which
        propagate to the caller.
        """
        # The module-level name ``time`` is rebound to the ``time`` module by
        # a later ``import time``, so bind the clock function locally.
        from time import time as wall_clock

        servers = self._servers
        greenlets = self._greenlets
        _poll = self._epoll.poll
        idles = self._idles
        while True:
            events = _poll(poll_interval)
            for fileno, event in events:
                if fileno in servers:
                    # KeyboardInterrupt is deliberately not caught: it must
                    # stop the loop instead of only skipping remaining events.
                    try:
                        servers[fileno].handle_request()
                    except socket.error as e:
                        if e.errno == EMFILE:
                            raise
                        print_exc()
                    except Exception:
                        print_exc()
                elif fileno in greenlets:
                    # Only connections still tracked in ``greenlets`` are served.
                    if event & (EPOLLIN | EPOLLOUT):
                        try:
                            greenlets[fileno].switch()
                        except Exception:
                            # The greenlet failed: drop its bookkeeping.
                            gr = greenlets.pop(fileno, None)
                            if gr is not None:
                                gr.throw()
                            print_exc()
                    elif event & (EPOLLHUP | EPOLLERR):
                        try:
                            # Ask the connection greenlet to exit.
                            greenlets[fileno].throw()
                        except Exception:
                            pass
                        finally:
                            gr = greenlets.pop(fileno, None)
                            if gr is not None:
                                gr.throw()
            while idles:
                now_ts = wall_clock()
                ts, gr = idles.pop(0)
                if ts > now_ts:
                    idles.append((ts, gr))
                    # Sort on the timestamp only: greenlets are not orderable,
                    # so a full tuple sort fails when two deadlines tie.
                    idles.sort(key=lambda item: item[0])
                    break
                gr.switch()
class TCPServer(object):
    """Classic Python TCPServer.

    Non-blocking variant driven by the module-level ``epoll`` loop: the
    listening socket is registered with it, and every accepted connection is
    served inside its own greenlet.
    """

    allow_reuse_address = True
    request_queue_size = 4194304
    address_family, socket_type = socket.AF_INET, socket.SOCK_STREAM

    def __init__(self, server_address, RequestHandlerClass, bind_and_activate=True):
        """Create the listening socket and, by default, bind and listen.

        Args:
            server_address: address tuple passed to ``socket.bind``.
            RequestHandlerClass: called as
                ``RequestHandlerClass(request, rw, environ, server)`` per request.
            bind_and_activate: when true, bind and listen immediately; the
                socket is closed again if either step fails.
        """
        self.server_address = server_address
        self.RequestHandlerClass = RequestHandlerClass
        self.socket = socket.socket(self.address_family, self.socket_type)
        self._started = False
        if bind_and_activate:
            try:
                self.server_bind()
                self.server_activate()
            except Exception:
                self.server_close()
                raise

    def server_bind(self):
        """Set socket options, bind, and switch the socket to non-blocking."""
        if self.allow_reuse_address:
            self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            # SO_REUSEPORT lets several processes bind the same port.
            try:
                self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
            except (AttributeError, OSError):
                # Some platforms do not support SO_REUSEPORT.
                pass
        self.socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
        logging.info("bind %s:%s", *self.server_address)
        self.socket.bind(self.server_address)
        self.socket.setblocking(0)

    def server_activate(self):
        """Start listening with ``request_queue_size`` as the backlog."""
        self.socket.listen(self.request_queue_size)

    def server_close(self):
        """Close the listening socket."""
        self.socket.close()

    def fileno(self):
        """Return the listening socket's fd (used as the epoll key)."""
        return self.socket.fileno()

    def get_request(self):
        """Accept one pending connection; returns ``(socket, address)``."""
        return self.socket.accept()

    def handle_request(self):
        """Accept and dispatch every connection that is currently pending."""
        self._handle_request_noblock()

    def _handle_request_noblock(self):
        # The listening socket is registered edge-triggered, so keep accepting
        # until the kernel reports the backlog is empty (EAGAIN/EWOULDBLOCK).
        while True:
            try:
                request, client_address = self.get_request()
            except socket.error as e:
                if e.args[0] in (EAGAIN, EWOULDBLOCK):
                    return
                raise
            if self.verify_request(request, client_address):
                try:
                    self.process_request(request, client_address)
                except Exception:
                    self.handle_error(request, client_address)
                    self.shutdown_request(request)
            else:
                self.shutdown_request(request)

    def handle_timeout(self):
        pass

    def verify_request(self, request, client_address):
        """Hook to reject connections; accepts everything by default."""
        return True

    def process_request(self, request, client_address):
        """Serve *request*, reporting and closing it if that fails."""
        try:
            self.finish_request(request, client_address)
        except Exception:
            self.handle_error(request, client_address)
            self.shutdown_request(request)

    def finish_request(self, request, client_address):
        """Start a greenlet for *request* and track it under its fd."""
        request.setblocking(0)
        fileno = request.fileno()
        epoll._greenlets[fileno] = curr = greenlet(
            partial(self._finish_request, request, client_address)
        )
        curr.switch()

    @staticmethod
    def _build_error_response(status_code, message):
        """Return the raw HTTP/1.1 error response for *status_code*.

        Content-Length is the byte length of the UTF-8 body, so non-ASCII
        messages are framed correctly.
        """
        body = str(message).encode("utf-8")
        head = (
            f"HTTP/1.1 {status_code} {message}\r\n"
            "Content-Type: text/html; charset=utf-8\r\n"
            f"Content-Length: {len(body)}\r\n"
            "\r\n"
        )
        return head.encode("utf-8") + body

    def _finish_request(self, request, client_address):
        # Greenlet body: parse the request, run the handler, and always clean
        # up the epoll bookkeeping and the socket wrapper afterwards.
        raw = SocketIO(self, request)
        try:
            rw = BufferedRWPair(raw, raw, DEFAULT_BUFFER_SIZE)
            environ = make_environ(self, rw, client_address)
            self.RequestHandlerClass(request, rw, environ, self)
            self.shutdown_request(request)
        except socket.error as e:
            if e.errno == EPIPE:
                raise GreenletExit
            raise
        except HttpError as e:
            # Best effort: report the error to the client, then close.
            try:
                rw.write(self._build_error_response(e.status_code, e.message))
                if hasattr(rw, "flush"):
                    rw.flush()
            except Exception:
                pass
            finally:
                self.shutdown_request(request)
        finally:
            try:
                try:
                    if raw.read_gr is not None:
                        raw.read_gr.throw()
                    if raw.write_gr is not None:
                        raw.write_gr.throw()
                finally:
                    fileno = raw.fileno()
                    gr = epoll._greenlets.pop(fileno, None)
                    if gr is not None:
                        # NOTE(review): when running inside the greenlet made
                        # by finish_request, gr is presumably the current
                        # greenlet, so this raises GreenletExit right here and
                        # replaces any in-flight exception — confirm intended.
                        gr.throw()
            finally:
                # Make sure the connection is closed in every case.
                if not raw.closed:
                    try:
                        raw.close()
                    except Exception:
                        pass

    def shutdown_request(self, request):
        """Half-close the write side (ignoring errors) and close *request*."""
        try:
            request.shutdown(socket.SHUT_WR)
        except OSError:
            pass
        self.close_request(request)

    def close_request(self, request):
        request.close()

    def handle_error(self, request, client_address):
        traceback.print_exc()

    def server_forever(self, poll_interval=0.1):
        """Register with the event loop (marking the server started) and run it."""
        self.start()
        mainloop(poll_interval=poll_interval)

    def start(self):
        """Register the listening socket with the event loop once."""
        if not self._started:
            epoll.register(self)
            self._started = True

    def shutdown(self):
        """Unregister the listening socket from the event loop if started."""
        if self._started:
            epoll.unregister(self)
            self._started = False
class HTTPServer(TCPServer):
    """TCP server that also records the host name and port it is bound to."""

    max_request_size = 10 * 1024 * 1024  # upper bound for a declared Content-Length (10 MiB)
    allow_reuse_address = 1

    def server_bind(self):
        """Bind as TCPServer does, then cache ``server_name``/``server_port``."""
        TCPServer.server_bind(self)
        sockname = self.socket.getsockname()
        bound_host = sockname[0]
        self.server_name = socket.getfqdn(bound_host)
        self.server_port = sockname[1]
class ProcessHTTPServer(HTTPServer):
    """Multi-process HTTP server.

    The listening socket is bound in the parent by :meth:`server_forever`
    and then served by ``processes`` ``multiprocessing.Process`` workers,
    each running its own epoll loop. NOTE(review): workers receive the
    socket through ``self``, which presumably requires the "fork" start
    method — confirm.
    """

    def __init__(self, server_address, RequestHandlerClass, bind_and_activate=True, processes=4):
        """Initialise the multi-process HTTP server.

        Args:
            server_address: address to listen on.
            RequestHandlerClass: request handler class.
            bind_and_activate: accepted for signature compatibility but
                ignored; binding is always deferred to server_forever().
            processes: number of worker processes (default 4).
        """
        # Defer binding so a hot reload does not find the port still occupied.
        super().__init__(server_address, RequestHandlerClass, False)
        self.processes = processes
        self.workers = []

    def server_forever(self, poll_interval=0.1):
        """Bind and listen, start the workers, and wait for them to exit.

        On KeyboardInterrupt the whole server is shut down via shutdown().
        """
        self.server_bind()
        self.server_activate()
        for i in range(self.processes):
            worker = multiprocessing.Process(target=self._run_worker, args=(i, poll_interval))
            worker.daemon = True
            worker.start()
            self.workers.append(worker)
            logging.info(f"启动工作进程 {i}")
        try:
            for worker in self.workers:
                worker.join()
        except KeyboardInterrupt:
            logging.info("接收到中断信号,停止服务器")
            self.shutdown()

    def _run_worker(self, worker_id, poll_interval):
        """Worker entry point: run an event loop on the shared listening socket."""
        global epoll
        logging.info(f"工作进程 {worker_id} 启动,PID: {os.getpid()}")
        try:
            # Each process builds its own epoll instance instead of reusing the
            # module-level one, which may have been inherited from the parent.
            epoll = Epoll() if HAS_EPOLL else None
            if HAS_EPOLL:
                if not hasattr(self, '_started') or not self._started:
                    epoll.register(self)
                    self._started = True
                while True:
                    epoll.poll(poll_interval=poll_interval)
            else:
                # NOTE(review): the listening socket is non-blocking and
                # finish_request needs epoll, so this fallback busy-spins and
                # cannot actually serve requests — confirm before relying on it.
                while True:
                    self.handle_request()
        except Exception as e:
            logging.error(f"工作进程 {worker_id} 出错: {e}")
            traceback.print_exc()

    def _check_port_free(self, host, port, timeout=5):
        """Poll until *port* on *host* can be bound again.

        Args:
            host: host name or address to test.
            port: port number.
            timeout: maximum number of seconds to keep trying.

        Returns:
            bool: True once a test bind succeeds, False after *timeout*.
        """
        start_time = time.time()
        while time.time() - start_time < timeout:
            try:
                # ``with`` closes the probe socket even when bind() fails,
                # so failed attempts no longer leak a descriptor.
                with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
                    sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
                    sock.bind((host, port))
                return True
            except socket.error:
                # Port still in use; try again shortly.
                time.sleep(0.1)
        return False

    def _get_all_children(self, process):
        """Return all descendants of *process* as psutil.Process objects.

        Args:
            process: object with a ``pid`` attribute (e.g. multiprocessing.Process).

        Returns:
            list: descendants found via psutil; empty when psutil is not
            installed or the lookup fails.
        """
        all_children = []
        try:
            try:
                import psutil
                parent = psutil.Process(process.pid)
                all_children.extend(parent.children(recursive=True))
            except ImportError:
                # Without psutil the process tree cannot be walked.
                pass
        except Exception:
            pass
        return all_children

    def _stop_workers(self, alive_workers):
        """Terminate *alive_workers*, wait up to 10 s, then kill stragglers."""
        for worker in alive_workers:
            if worker.is_alive():
                try:
                    # Sends SIGTERM.
                    worker.terminate()
                except Exception:
                    pass
        if alive_workers:
            logging.info("等待工作进程关闭...")
            start_time = time.time()
            timeout = 10  # seconds
            while time.time() - start_time < timeout:
                all_dead = True
                for worker in self.workers:
                    if worker.is_alive():
                        all_dead = False
                        worker.join(timeout=0.1)
                if all_dead:
                    logging.info("所有工作进程已关闭")
                    break
                time.sleep(0.1)
        # Anything still alive after the grace period is killed outright.
        for worker in self.workers:
            if worker.is_alive():
                try:
                    worker.kill()
                    worker.join(timeout=1)
                except Exception:
                    pass

    def _kill_port_listeners(self, port):
        """Best effort: terminate other processes listening on *port*.

        Any process other than the current one with a LISTEN socket on *port*
        is terminated together with its children — not only this server's
        workers. Does nothing when psutil is not installed; all errors are
        swallowed.
        """
        try:
            import psutil
        except ImportError:
            return
        try:
            current_pid = os.getpid()
            # NOTE(review): psutil 6.0 renamed Process.connections() to
            # net_connections(); confirm 'connections' is still accepted as a
            # process_iter attribute, otherwise the resulting error is
            # swallowed below and this step silently does nothing.
            for proc in psutil.process_iter(['pid', 'name', 'connections']):
                try:
                    if proc.info['pid'] == current_pid:
                        continue
                    connections = proc.info.get('connections', [])
                    if not connections:
                        continue
                    for conn in connections:
                        if conn.status == 'LISTEN' and conn.laddr.port == port:
                            for child in proc.children(recursive=True):
                                try:
                                    child.terminate()
                                except Exception:
                                    pass
                            try:
                                proc.terminate()
                                proc.wait(timeout=5)
                            except psutil.TimeoutExpired:
                                proc.kill()
                                proc.wait()
                            except Exception:
                                pass
                            break
                except (psutil.NoSuchProcess, psutil.AccessDenied):
                    pass
                except Exception:
                    pass
        except Exception:
            pass

    def shutdown(self):
        """Stop all workers, close the listening socket and wait for the port.

        Progress messages are only logged when there was something to stop.
        """
        host, port = self.server_address
        # os.getpid() == os.getppid() can never hold, so the original guard
        # reduced to "this object still has workers".
        if self.workers:
            logging.info("开始关闭服务器...")
        alive_workers = [w for w in self.workers if w.is_alive()]
        if alive_workers:
            logging.info(f"发现 {len(alive_workers)} 个活跃的工作进程")
        self._stop_workers(alive_workers)
        self.workers = []
        try:
            self.server_close()
            if alive_workers:
                logging.info("服务器套接字已关闭")
        except Exception:
            pass
        self._kill_port_listeners(port)
        # Give the kernel a moment to release the port, then verify.
        time.sleep(1)
        if self._check_port_free(host, port, timeout=5):
            if alive_workers:
                logging.info("端口已释放,服务器已完全关闭")
        else:
            if alive_workers:
                logging.warning("端口可能未完全释放,服务器已尝试关闭")
        if alive_workers:
            logging.info("服务器已关闭")
class WSGIServer(HTTPServer):
    """HTTP server carrying a WSGI application and a base environ template."""

    application = None

    def server_bind(self):
        """Bind as HTTPServer does, then rebuild ``base_environ``."""
        HTTPServer.server_bind(self)
        self.setup_environ()

    def setup_environ(self):
        """Store the server-wide environ defaults on ``self.base_environ``."""
        self.base_environ = {
            "SERVER_NAME": self.server_name,
            "SERVER_PORT": str(self.server_port),
            "REMOTE_HOST": "",
            "CONTENT_LENGTH": -1,
            "SCRIPT_NAME": "",
        }

    def get_app(self):
        """Return the configured application (None when unset)."""
        return self.application

    def set_app(self, application):
        """Install *application* on this server instance."""
        self.application = application
def mainloop(poll_interval=0.1):
    """Run the module-level event loop, returning quietly on Ctrl+C.

    Args:
        poll_interval: timeout in seconds handed to each epoll wait.
    """
    loop = epoll
    try:
        loop.poll(poll_interval=poll_interval)
    except KeyboardInterrupt:
        return None
# Module-level alias: running "the server" means running the shared loop.
server_forever = mainloop
# Process-wide event loop; None where select.epoll is unavailable.
if HAS_EPOLL:
    epoll = Epoll()
else:
    epoll = None