Fix binding token extraction and harden startup concurrency

2026-03-05 14:40:27 +08:00
parent feb99faaf3
commit 7ed6f70bab
9 changed files with 96 additions and 17 deletions

View File

@@ -8,7 +8,9 @@ ADMIN_JWT_SECRET=replace-with-a-random-jwt-secret
TRUSTED_PROXY_IPS=172.30.0.0/24
SENTINEL_FAILSAFE_MODE=closed
APP_PORT=7000
UVICORN_WORKERS=4
ALERT_WEBHOOK_URL=
ALERT_THRESHOLD_COUNT=5
ALERT_THRESHOLD_SECONDS=300
ARCHIVE_DAYS=90
ARCHIVE_SCHEDULER_LOCK_KEY=2026030502

View File

@@ -6,4 +6,4 @@ RUN pip install --no-cache-dir --prefix=/install -r requirements.txt
FROM python:3.13-slim-bookworm
WORKDIR /app
COPY --from=builder /install /usr/local
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "7000", "--workers", "4"]
CMD ["sh", "-c", "uvicorn app.main:app --host 0.0.0.0 --port ${APP_PORT:-7000} --workers ${UVICORN_WORKERS:-4}"]

View File

@@ -54,6 +54,7 @@ class Settings(BaseSettings):
admin_jwt_expire_hours: int = 8
archive_job_interval_minutes: int = 60
archive_batch_size: int = 500
archive_scheduler_lock_key: int = Field(default=2026030502, alias="ARCHIVE_SCHEDULER_LOCK_KEY")
metrics_ttl_days: int = 30
webhook_timeout_seconds: int = 5

View File

@@ -3,6 +3,7 @@ from __future__ import annotations
import hashlib
import hmac
from datetime import UTC, datetime, timedelta
from typing import Mapping
from fastapi import HTTPException, status
from jose import JWTError, jwt
@@ -34,6 +35,19 @@ def extract_bearer_token(authorization: str | None) -> str | None:
return token.strip()
def extract_request_token(headers: Mapping[str, str]) -> tuple[str | None, str | None]:
bearer_token = extract_bearer_token(headers.get("authorization"))
if bearer_token:
return bearer_token, "authorization"
for header_name in ("x-api-key", "api-key"):
header_value = headers.get(header_name)
if header_value and header_value.strip():
return header_value.strip(), header_name
return None, None
def verify_admin_password(password: str, settings: Settings) -> bool:
return hmac.compare_digest(password, settings.admin_password)
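A quick sketch of the precedence the new extract_request_token helper implements, using plain dicts in place of Starlette's case-insensitive header mapping (the token values are placeholders):

from app.core.security import extract_request_token

# Authorization wins when both styles are present; this assumes
# extract_bearer_token strips the "Bearer " prefix, as its name suggests.
print(extract_request_token({"authorization": "Bearer abc123", "x-api-key": "other"}))
# -> ("abc123", "authorization")

# Falls back to x-api-key / api-key and strips surrounding whitespace.
print(extract_request_token({"x-api-key": " abc123 "}))
# -> ("abc123", "x-api-key")

# No recognised header: both parts are None and the proxy skips the binding check.
print(extract_request_token({}))
# -> (None, None)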

View File

@@ -14,7 +14,7 @@ from redis.asyncio import from_url as redis_from_url
from app.api import auth, bindings, dashboard, logs, settings as settings_api
from app.config import RUNTIME_SETTINGS_REDIS_KEY, RuntimeSettings, Settings, get_settings
from app.models import intercept_log, token_binding # noqa: F401
from app.models.db import close_db, ensure_schema_compatibility, get_session_factory, init_db
from app.models.db import close_db, ensure_schema_compatibility, get_engine, get_session_factory, init_db
from app.proxy.handler import router as proxy_router
from app.services.alert_service import AlertService
from app.services.archive_service import ArchiveService
@@ -70,6 +70,8 @@ def configure_logging() -> None:
root_logger.handlers.clear()
root_logger.addHandler(handler)
root_logger.setLevel(logging.INFO)
logging.getLogger("httpx").setLevel(logging.WARNING)
logging.getLogger("httpcore").setLevel(logging.WARNING)
configure_logging()
@@ -153,6 +155,7 @@ async def lifespan(app: FastAPI):
)
archive_service = ArchiveService(
settings=settings,
engine=get_engine(),
session_factory=session_factory,
binding_service=binding_service,
runtime_settings_getter=lambda: app.state.runtime_settings,
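The rest of the lifespan is not shown in this hunk, but presumably it awaits archive_service.start() after this construction and archive_service.stop() on shutdown; the point of passing get_engine() is that start() now needs its own database connection for the leader lock. A rough sketch of that assumed wiring:

    # ... inside lifespan(), after the ArchiveService(...) construction above ...
    await archive_service.start()     # scheduler only runs on the worker that wins the leader lock
    try:
        yield
    finally:
        await archive_service.stop()  # releases the leader lock if this worker held it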

View File

@@ -6,6 +6,8 @@ from sqlalchemy.orm import DeclarativeBase
from app.config import Settings
SCHEMA_COMPATIBILITY_LOCK_KEY = 2026030501
class Base(DeclarativeBase):
pass
@@ -62,6 +64,10 @@ async def ensure_schema_compatibility() -> None:
"CREATE INDEX IF NOT EXISTS idx_token_bindings_ip ON token_bindings(bound_ip)",
]
async with engine.begin() as connection:
await connection.execute(
text("SELECT pg_advisory_xact_lock(:lock_key)"),
{"lock_key": SCHEMA_COMPATIBILITY_LOCK_KEY},
)
for statement in statements:
await connection.execute(text(statement))
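pg_advisory_xact_lock blocks until the key is free and is released automatically when the surrounding transaction commits or rolls back, so the engine.begin() block above already scopes it correctly: with several uvicorn workers booting at once, only one runs the DDL at a time and the rest simply wait behind it. A minimal standalone sketch of the same pattern (the DSN is a placeholder; the key and index statement mirror the ones above):

from sqlalchemy import text
from sqlalchemy.ext.asyncio import create_async_engine

async def guarded_ddl(dsn: str, lock_key: int = 2026030501) -> None:
    engine = create_async_engine(dsn)
    async with engine.begin() as connection:
        # Blocks here if another worker holds the key; released at commit.
        await connection.execute(text("SELECT pg_advisory_xact_lock(:k)"), {"k": lock_key})
        await connection.execute(
            text("CREATE INDEX IF NOT EXISTS idx_token_bindings_ip ON token_bindings(bound_ip)")
        )
    await engine.dispose()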

View File

@@ -9,7 +9,7 @@ from fastapi.responses import JSONResponse, StreamingResponse
from app.config import Settings
from app.core.ip_utils import extract_client_ip
from app.core.security import extract_bearer_token
from app.core.security import extract_request_token
from app.dependencies import get_alert_service, get_binding_service, get_settings
from app.services.alert_service import AlertService
from app.services.binding_service import BindingService
@@ -56,7 +56,7 @@ async def reverse_proxy(
alert_service: AlertService = Depends(get_alert_service),
):
client_ip = extract_client_ip(request, settings)
token = extract_bearer_token(request.headers.get("authorization"))
token, token_source = extract_request_token(request.headers)
if token:
binding_result = await binding_service.evaluate_token_binding(token, client_ip)
@@ -75,6 +75,7 @@ async def reverse_proxy(
status_code=binding_result.status_code,
content={"detail": binding_result.detail},
)
logger.debug("Token binding check passed.", extra={"client_ip": client_ip, "token_source": token_source})
else:
await binding_service.increment_request_metric("allowed")
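A small client-side illustration of what the handler now accepts, using httpx (already part of the stack, per the logging change above); the base URL, path, and token are placeholders:

import httpx

with httpx.Client(base_url="http://localhost:7000") as client:
    # Both requests resolve to the same token; only the logged token_source differs.
    client.get("/v1/models", headers={"Authorization": "Bearer sk-example"})
    client.get("/v1/models", headers={"x-api-key": "sk-example"})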

View File

@@ -5,9 +5,9 @@ from datetime import UTC, datetime, timedelta
from typing import Callable
from apscheduler.schedulers.asyncio import AsyncIOScheduler
from sqlalchemy import delete, select
from sqlalchemy import delete, select, text
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine, AsyncSession, async_sessionmaker
from app.config import RuntimeSettings, Settings
from app.models.token_binding import TokenBinding
@@ -20,33 +20,45 @@ class ArchiveService:
def __init__(
self,
settings: Settings,
engine: AsyncEngine,
session_factory: async_sessionmaker[AsyncSession],
binding_service: BindingService,
runtime_settings_getter: Callable[[], RuntimeSettings],
) -> None:
self.settings = settings
self.engine = engine
self.session_factory = session_factory
self.binding_service = binding_service
self.runtime_settings_getter = runtime_settings_getter
self.scheduler = AsyncIOScheduler(timezone="UTC")
self._leader_connection: AsyncConnection | None = None
async def start(self) -> None:
if self.scheduler.running:
return
self.scheduler.add_job(
self.archive_inactive_bindings,
trigger="interval",
minutes=self.settings.archive_job_interval_minutes,
id="archive-inactive-bindings",
replace_existing=True,
max_instances=1,
coalesce=True,
)
self.scheduler.start()
if not await self._acquire_leader_lock():
logger.info("Archive scheduler leader lock not acquired; skipping local scheduler start.")
return
try:
self.scheduler.add_job(
self.archive_inactive_bindings,
trigger="interval",
minutes=self.settings.archive_job_interval_minutes,
id="archive-inactive-bindings",
replace_existing=True,
max_instances=1,
coalesce=True,
)
self.scheduler.start()
except Exception:
await self._release_leader_lock()
raise
logger.info("Archive scheduler started on current worker.")
async def stop(self) -> None:
if self.scheduler.running:
self.scheduler.shutdown(wait=False)
await self._release_leader_lock()
async def archive_inactive_bindings(self) -> int:
runtime_settings = self.runtime_settings_getter()
@@ -82,3 +94,43 @@ class ArchiveService:
if total_archived:
logger.info("Archived inactive bindings.", extra={"count": total_archived})
return total_archived
async def _acquire_leader_lock(self) -> bool:
if self._leader_connection is not None:
return True
connection = await self.engine.connect()
try:
acquired = bool(
await connection.scalar(
text("SELECT pg_try_advisory_lock(:lock_key)"),
{"lock_key": self.settings.archive_scheduler_lock_key},
)
)
except Exception:
await connection.close()
logger.exception("Failed to acquire archive scheduler leader lock.")
return False
if not acquired:
await connection.close()
return False
self._leader_connection = connection
return True
async def _release_leader_lock(self) -> None:
if self._leader_connection is None:
return
connection = self._leader_connection
self._leader_connection = None
try:
await connection.execute(
text("SELECT pg_advisory_unlock(:lock_key)"),
{"lock_key": self.settings.archive_scheduler_lock_key},
)
except Exception:
logger.warning("Failed to release archive scheduler leader lock cleanly.")
finally:
await connection.close()
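Unlike the transaction-scoped lock used for the schema DDL, pg_try_advisory_lock here is session-level and non-blocking: it returns true or false immediately, and the lock is held until it is explicitly unlocked or the owning session ends, which is why the acquiring connection is parked on self._leader_connection until stop(). A stripped-down sketch of the same election step (engine and key supplied by the caller):

from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine

async def try_become_leader(engine: AsyncEngine, lock_key: int) -> AsyncConnection | None:
    """Return the connection holding the lock, or None if another worker already leads."""
    connection = await engine.connect()
    acquired = bool(
        await connection.scalar(text("SELECT pg_try_advisory_lock(:k)"), {"k": lock_key})
    )
    if acquired:
        # Keep the connection open: closing it would end the session and drop the lock.
        return connection
    await connection.close()
    return None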

View File

@@ -1,4 +1,4 @@
worker_processes auto;
worker_processes 8;
events {
worker_connections 4096;