Fix binding token extraction and harden startup concurrency
This commit is contained in:
@@ -8,7 +8,9 @@ ADMIN_JWT_SECRET=replace-with-a-random-jwt-secret
|
|||||||
TRUSTED_PROXY_IPS=172.30.0.0/24
|
TRUSTED_PROXY_IPS=172.30.0.0/24
|
||||||
SENTINEL_FAILSAFE_MODE=closed
|
SENTINEL_FAILSAFE_MODE=closed
|
||||||
APP_PORT=7000
|
APP_PORT=7000
|
||||||
|
UVICORN_WORKERS=4
|
||||||
ALERT_WEBHOOK_URL=
|
ALERT_WEBHOOK_URL=
|
||||||
ALERT_THRESHOLD_COUNT=5
|
ALERT_THRESHOLD_COUNT=5
|
||||||
ALERT_THRESHOLD_SECONDS=300
|
ALERT_THRESHOLD_SECONDS=300
|
||||||
ARCHIVE_DAYS=90
|
ARCHIVE_DAYS=90
|
||||||
|
ARCHIVE_SCHEDULER_LOCK_KEY=2026030502
|
||||||
|
|||||||
@@ -6,4 +6,4 @@ RUN pip install --no-cache-dir --prefix=/install -r requirements.txt
|
|||||||
FROM python:3.13-slim-bookworm
|
FROM python:3.13-slim-bookworm
|
||||||
WORKDIR /app
|
WORKDIR /app
|
||||||
COPY --from=builder /install /usr/local
|
COPY --from=builder /install /usr/local
|
||||||
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "7000", "--workers", "4"]
|
CMD ["sh", "-c", "uvicorn app.main:app --host 0.0.0.0 --port ${APP_PORT:-7000} --workers ${UVICORN_WORKERS:-4}"]
|
||||||
|
|||||||
@@ -54,6 +54,7 @@ class Settings(BaseSettings):
|
|||||||
admin_jwt_expire_hours: int = 8
|
admin_jwt_expire_hours: int = 8
|
||||||
archive_job_interval_minutes: int = 60
|
archive_job_interval_minutes: int = 60
|
||||||
archive_batch_size: int = 500
|
archive_batch_size: int = 500
|
||||||
|
archive_scheduler_lock_key: int = Field(default=2026030502, alias="ARCHIVE_SCHEDULER_LOCK_KEY")
|
||||||
metrics_ttl_days: int = 30
|
metrics_ttl_days: int = 30
|
||||||
webhook_timeout_seconds: int = 5
|
webhook_timeout_seconds: int = 5
|
||||||
|
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ from __future__ import annotations
|
|||||||
import hashlib
|
import hashlib
|
||||||
import hmac
|
import hmac
|
||||||
from datetime import UTC, datetime, timedelta
|
from datetime import UTC, datetime, timedelta
|
||||||
|
from typing import Mapping
|
||||||
|
|
||||||
from fastapi import HTTPException, status
|
from fastapi import HTTPException, status
|
||||||
from jose import JWTError, jwt
|
from jose import JWTError, jwt
|
||||||
@@ -34,6 +35,19 @@ def extract_bearer_token(authorization: str | None) -> str | None:
|
|||||||
return token.strip()
|
return token.strip()
|
||||||
|
|
||||||
|
|
||||||
|
def extract_request_token(headers: Mapping[str, str]) -> tuple[str | None, str | None]:
    """Find the caller's credential in the request headers.

    The ``authorization`` header is handed to ``extract_bearer_token`` first.
    If that yields nothing, ``x-api-key`` and then ``api-key`` are consulted,
    in that order, and the first one holding a non-blank value wins.

    Args:
        headers: Header mapping to read from; values are looked up with the
            lower-case names shown above.

    Returns:
        A ``(token, source)`` pair where ``source`` is the name of the header
        the token came from, or ``(None, None)`` when no header carried one.
    """
    token = extract_bearer_token(headers.get("authorization"))
    if token:
        return token, "authorization"

    # API-key style headers are only a fallback for a missing bearer token.
    fallback_headers = ("x-api-key", "api-key")
    for source in fallback_headers:
        candidate = (headers.get(source) or "").strip()
        if candidate:
            return candidate, source

    return None, None
|
||||||
|
|
||||||
|
|
||||||
def verify_admin_password(password: str, settings: Settings) -> bool:
|
def verify_admin_password(password: str, settings: Settings) -> bool:
|
||||||
return hmac.compare_digest(password, settings.admin_password)
|
return hmac.compare_digest(password, settings.admin_password)
|
||||||
|
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ from redis.asyncio import from_url as redis_from_url
|
|||||||
from app.api import auth, bindings, dashboard, logs, settings as settings_api
|
from app.api import auth, bindings, dashboard, logs, settings as settings_api
|
||||||
from app.config import RUNTIME_SETTINGS_REDIS_KEY, RuntimeSettings, Settings, get_settings
|
from app.config import RUNTIME_SETTINGS_REDIS_KEY, RuntimeSettings, Settings, get_settings
|
||||||
from app.models import intercept_log, token_binding # noqa: F401
|
from app.models import intercept_log, token_binding # noqa: F401
|
||||||
from app.models.db import close_db, ensure_schema_compatibility, get_session_factory, init_db
|
from app.models.db import close_db, ensure_schema_compatibility, get_engine, get_session_factory, init_db
|
||||||
from app.proxy.handler import router as proxy_router
|
from app.proxy.handler import router as proxy_router
|
||||||
from app.services.alert_service import AlertService
|
from app.services.alert_service import AlertService
|
||||||
from app.services.archive_service import ArchiveService
|
from app.services.archive_service import ArchiveService
|
||||||
@@ -70,6 +70,8 @@ def configure_logging() -> None:
|
|||||||
root_logger.handlers.clear()
|
root_logger.handlers.clear()
|
||||||
root_logger.addHandler(handler)
|
root_logger.addHandler(handler)
|
||||||
root_logger.setLevel(logging.INFO)
|
root_logger.setLevel(logging.INFO)
|
||||||
|
logging.getLogger("httpx").setLevel(logging.WARNING)
|
||||||
|
logging.getLogger("httpcore").setLevel(logging.WARNING)
|
||||||
|
|
||||||
|
|
||||||
configure_logging()
|
configure_logging()
|
||||||
@@ -153,6 +155,7 @@ async def lifespan(app: FastAPI):
|
|||||||
)
|
)
|
||||||
archive_service = ArchiveService(
|
archive_service = ArchiveService(
|
||||||
settings=settings,
|
settings=settings,
|
||||||
|
engine=get_engine(),
|
||||||
session_factory=session_factory,
|
session_factory=session_factory,
|
||||||
binding_service=binding_service,
|
binding_service=binding_service,
|
||||||
runtime_settings_getter=lambda: app.state.runtime_settings,
|
runtime_settings_getter=lambda: app.state.runtime_settings,
|
||||||
|
|||||||
@@ -6,6 +6,8 @@ from sqlalchemy.orm import DeclarativeBase
|
|||||||
|
|
||||||
from app.config import Settings
|
from app.config import Settings
|
||||||
|
|
||||||
|
SCHEMA_COMPATIBILITY_LOCK_KEY = 2026030501
|
||||||
|
|
||||||
|
|
||||||
class Base(DeclarativeBase):
|
class Base(DeclarativeBase):
|
||||||
pass
|
pass
|
||||||
@@ -62,6 +64,10 @@ async def ensure_schema_compatibility() -> None:
|
|||||||
"CREATE INDEX IF NOT EXISTS idx_token_bindings_ip ON token_bindings(bound_ip)",
|
"CREATE INDEX IF NOT EXISTS idx_token_bindings_ip ON token_bindings(bound_ip)",
|
||||||
]
|
]
|
||||||
async with engine.begin() as connection:
|
async with engine.begin() as connection:
|
||||||
|
await connection.execute(
|
||||||
|
text("SELECT pg_advisory_xact_lock(:lock_key)"),
|
||||||
|
{"lock_key": SCHEMA_COMPATIBILITY_LOCK_KEY},
|
||||||
|
)
|
||||||
for statement in statements:
|
for statement in statements:
|
||||||
await connection.execute(text(statement))
|
await connection.execute(text(statement))
|
||||||
|
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ from fastapi.responses import JSONResponse, StreamingResponse
|
|||||||
|
|
||||||
from app.config import Settings
|
from app.config import Settings
|
||||||
from app.core.ip_utils import extract_client_ip
|
from app.core.ip_utils import extract_client_ip
|
||||||
from app.core.security import extract_bearer_token
|
from app.core.security import extract_request_token
|
||||||
from app.dependencies import get_alert_service, get_binding_service, get_settings
|
from app.dependencies import get_alert_service, get_binding_service, get_settings
|
||||||
from app.services.alert_service import AlertService
|
from app.services.alert_service import AlertService
|
||||||
from app.services.binding_service import BindingService
|
from app.services.binding_service import BindingService
|
||||||
@@ -56,7 +56,7 @@ async def reverse_proxy(
|
|||||||
alert_service: AlertService = Depends(get_alert_service),
|
alert_service: AlertService = Depends(get_alert_service),
|
||||||
):
|
):
|
||||||
client_ip = extract_client_ip(request, settings)
|
client_ip = extract_client_ip(request, settings)
|
||||||
token = extract_bearer_token(request.headers.get("authorization"))
|
token, token_source = extract_request_token(request.headers)
|
||||||
|
|
||||||
if token:
|
if token:
|
||||||
binding_result = await binding_service.evaluate_token_binding(token, client_ip)
|
binding_result = await binding_service.evaluate_token_binding(token, client_ip)
|
||||||
@@ -75,6 +75,7 @@ async def reverse_proxy(
|
|||||||
status_code=binding_result.status_code,
|
status_code=binding_result.status_code,
|
||||||
content={"detail": binding_result.detail},
|
content={"detail": binding_result.detail},
|
||||||
)
|
)
|
||||||
|
logger.debug("Token binding check passed.", extra={"client_ip": client_ip, "token_source": token_source})
|
||||||
else:
|
else:
|
||||||
await binding_service.increment_request_metric("allowed")
|
await binding_service.increment_request_metric("allowed")
|
||||||
|
|
||||||
|
|||||||
@@ -5,9 +5,9 @@ from datetime import UTC, datetime, timedelta
|
|||||||
from typing import Callable
|
from typing import Callable
|
||||||
|
|
||||||
from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
||||||
from sqlalchemy import delete, select
|
from sqlalchemy import delete, select, text
|
||||||
from sqlalchemy.exc import SQLAlchemyError
|
from sqlalchemy.exc import SQLAlchemyError
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine, AsyncSession, async_sessionmaker
|
||||||
|
|
||||||
from app.config import RuntimeSettings, Settings
|
from app.config import RuntimeSettings, Settings
|
||||||
from app.models.token_binding import TokenBinding
|
from app.models.token_binding import TokenBinding
|
||||||
@@ -20,19 +20,26 @@ class ArchiveService:
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
settings: Settings,
|
settings: Settings,
|
||||||
|
engine: AsyncEngine,
|
||||||
session_factory: async_sessionmaker[AsyncSession],
|
session_factory: async_sessionmaker[AsyncSession],
|
||||||
binding_service: BindingService,
|
binding_service: BindingService,
|
||||||
runtime_settings_getter: Callable[[], RuntimeSettings],
|
runtime_settings_getter: Callable[[], RuntimeSettings],
|
||||||
) -> None:
|
) -> None:
|
||||||
self.settings = settings
|
self.settings = settings
|
||||||
|
self.engine = engine
|
||||||
self.session_factory = session_factory
|
self.session_factory = session_factory
|
||||||
self.binding_service = binding_service
|
self.binding_service = binding_service
|
||||||
self.runtime_settings_getter = runtime_settings_getter
|
self.runtime_settings_getter = runtime_settings_getter
|
||||||
self.scheduler = AsyncIOScheduler(timezone="UTC")
|
self.scheduler = AsyncIOScheduler(timezone="UTC")
|
||||||
|
self._leader_connection: AsyncConnection | None = None
|
||||||
|
|
||||||
async def start(self) -> None:
|
async def start(self) -> None:
|
||||||
if self.scheduler.running:
|
if self.scheduler.running:
|
||||||
return
|
return
|
||||||
|
if not await self._acquire_leader_lock():
|
||||||
|
logger.info("Archive scheduler leader lock not acquired; skipping local scheduler start.")
|
||||||
|
return
|
||||||
|
try:
|
||||||
self.scheduler.add_job(
|
self.scheduler.add_job(
|
||||||
self.archive_inactive_bindings,
|
self.archive_inactive_bindings,
|
||||||
trigger="interval",
|
trigger="interval",
|
||||||
@@ -43,10 +50,15 @@ class ArchiveService:
|
|||||||
coalesce=True,
|
coalesce=True,
|
||||||
)
|
)
|
||||||
self.scheduler.start()
|
self.scheduler.start()
|
||||||
|
except Exception:
|
||||||
|
await self._release_leader_lock()
|
||||||
|
raise
|
||||||
|
logger.info("Archive scheduler started on current worker.")
|
||||||
|
|
||||||
async def stop(self) -> None:
|
async def stop(self) -> None:
|
||||||
if self.scheduler.running:
|
if self.scheduler.running:
|
||||||
self.scheduler.shutdown(wait=False)
|
self.scheduler.shutdown(wait=False)
|
||||||
|
await self._release_leader_lock()
|
||||||
|
|
||||||
async def archive_inactive_bindings(self) -> int:
|
async def archive_inactive_bindings(self) -> int:
|
||||||
runtime_settings = self.runtime_settings_getter()
|
runtime_settings = self.runtime_settings_getter()
|
||||||
@@ -82,3 +94,43 @@ class ArchiveService:
|
|||||||
if total_archived:
|
if total_archived:
|
||||||
logger.info("Archived inactive bindings.", extra={"count": total_archived})
|
logger.info("Archived inactive bindings.", extra={"count": total_archived})
|
||||||
return total_archived
|
return total_archived
|
||||||
|
|
||||||
|
async def _acquire_leader_lock(self) -> bool:
    """Try to make this process the archive-scheduler leader.

    A dedicated connection is checked out from ``self.engine`` and
    ``pg_try_advisory_lock`` is called with
    ``self.settings.archive_scheduler_lock_key``. When the lock is granted
    the connection is kept in ``self._leader_connection`` so the
    session-level lock stays held until ``_release_leader_lock`` runs.

    Returns:
        ``True`` if the lock is held by this instance (including the case
        where it was already held), ``False`` if another session owns it or
        the attempt failed. Failures are logged, never raised.
    """
    if self._leader_connection is not None:
        return True

    connection: AsyncConnection | None = None
    try:
        # Connecting is part of the attempt: a connect failure is reported
        # the same way as a failed lock query instead of escaping to the caller.
        connection = await self.engine.connect()
        acquired = bool(
            await connection.scalar(
                text("SELECT pg_try_advisory_lock(:lock_key)"),
                {"lock_key": self.settings.archive_scheduler_lock_key},
            )
        )
        # The SELECT implicitly began a transaction. End it now so the
        # long-lived leader connection does not sit idle-in-transaction;
        # session-level advisory locks are unaffected by commit/rollback.
        await connection.commit()
    except Exception:
        # Log first so the root cause is never masked by a cleanup error.
        logger.exception("Failed to acquire archive scheduler leader lock.")
        if connection is not None:
            try:
                # The lock may have been granted before the failure. Dropping
                # the underlying session guarantees it is not left held by a
                # connection that goes back to the pool.
                await connection.invalidate()
                await connection.close()
            except Exception:
                logger.warning(
                    "Failed to discard archive scheduler lock connection.",
                    exc_info=True,
                )
        return False

    if not acquired:
        await connection.close()
        return False

    self._leader_connection = connection
    return True
|
||||||
|
|
||||||
|
async def _release_leader_lock(self) -> None:
    """Give up the archive-scheduler leader lock, if this instance holds it.

    Runs ``pg_advisory_unlock`` on the connection stored in
    ``self._leader_connection`` and then closes that connection. The
    attribute is cleared before any I/O so a repeated call is a no-op.
    Errors while unlocking are logged and do not propagate.
    """
    if self._leader_connection is None:
        return

    connection = self._leader_connection
    self._leader_connection = None
    try:
        await connection.execute(
            text("SELECT pg_advisory_unlock(:lock_key)"),
            {"lock_key": self.settings.archive_scheduler_lock_key},
        )
    except Exception:
        # Keep the traceback: without it a stuck leader lock is undiagnosable.
        logger.warning(
            "Failed to release archive scheduler leader lock cleanly.",
            exc_info=True,
        )
        try:
            # close() only returns the connection to the pool, and a pool
            # reset does not drop session-level advisory locks. Invalidate so
            # the database session ends and the lock is released with it.
            await connection.invalidate()
        except Exception:
            logger.warning(
                "Failed to invalidate archive scheduler lock connection.",
                exc_info=True,
            )
    finally:
        await connection.close()
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
worker_processes auto;
|
worker_processes 8;
|
||||||
|
|
||||||
events {
|
events {
|
||||||
worker_connections 4096;
|
worker_connections 4096;
|
||||||
|
|||||||
Reference in New Issue
Block a user