"""Key-IP Sentinel application entry point (sentinel/app/main.py)."""
from __future__ import annotations
import asyncio
import json
import logging
from contextlib import asynccontextmanager
from datetime import UTC, datetime
import httpx
from fastapi import FastAPI
from redis.asyncio import Redis
from redis.asyncio import from_url as redis_from_url
from app.api import auth, bindings, dashboard, logs, settings as settings_api
from app.config import RUNTIME_SETTINGS_REDIS_KEY, RuntimeSettings, Settings, get_settings
from app.models import intercept_log, token_binding # noqa: F401
from app.models.db import close_db, ensure_schema_compatibility, get_engine, get_session_factory, init_db
from app.proxy.handler import router as proxy_router
from app.services.alert_service import AlertService
from app.services.archive_service import ArchiveService
from app.services.binding_service import BindingService
class JsonFormatter(logging.Formatter):
reserved = {
"args",
"asctime",
"created",
"exc_info",
"exc_text",
"filename",
"funcName",
"levelname",
"levelno",
"lineno",
"module",
"msecs",
"message",
"msg",
"name",
"pathname",
"process",
"processName",
"relativeCreated",
"stack_info",
"thread",
"threadName",
}
def format(self, record: logging.LogRecord) -> str:
payload = {
"timestamp": datetime.now(UTC).isoformat(),
"level": record.levelname,
"logger": record.name,
"message": record.getMessage(),
}
for key, value in record.__dict__.items():
if key in self.reserved or key.startswith("_"):
continue
payload[key] = value
if record.exc_info:
payload["exception"] = self.formatException(record.exc_info)
return json.dumps(payload, default=str)
def configure_logging() -> None:
    """Install a single JSON stream handler on the root logger.

    Any pre-existing handlers are removed so every line the process emits
    is structured JSON; the chatty HTTP client loggers are raised to
    WARNING to keep per-request noise out of the logs.
    """
    json_handler = logging.StreamHandler()
    json_handler.setFormatter(JsonFormatter())

    root = logging.getLogger()
    root.handlers.clear()
    root.addHandler(json_handler)
    root.setLevel(logging.INFO)

    for noisy in ("httpx", "httpcore"):
        logging.getLogger(noisy).setLevel(logging.WARNING)
# Configure structured logging at import time so records emitted during
# module import and application startup are already JSON-formatted.
configure_logging()
logger = logging.getLogger(__name__)
async def load_runtime_settings(redis: Redis | None, settings: Settings) -> RuntimeSettings:
    """Build the effective runtime settings, preferring values stored in Redis.

    Falls back to the environment-derived defaults from
    ``settings.build_runtime_settings()`` when Redis is unavailable, the
    settings hash is empty, or a stored value cannot be parsed — a broken
    or corrupted Redis must never prevent startup.

    Args:
        redis: Connected client, or ``None`` when running in degraded mode.
        settings: Static application settings.

    Returns:
        The merged ``RuntimeSettings``.
    """
    defaults = settings.build_runtime_settings()
    if redis is None:
        return defaults
    try:
        raw = await redis.hgetall(RUNTIME_SETTINGS_REDIS_KEY)
    except Exception:
        # Best-effort read; keep the traceback for diagnosis.
        logger.warning(
            "Failed to load runtime settings from Redis; using environment defaults.",
            exc_info=True,
        )
        return defaults
    if not raw:
        return defaults

    def _int_field(key: str, fallback: int) -> int:
        # Parse an integer field from the hash; corrupt values fall back
        # to the environment default instead of aborting startup.
        value = raw.get(key)
        if value is None:
            return fallback
        try:
            return int(value)
        except (TypeError, ValueError):
            logger.warning(
                "Invalid value %r for runtime setting %r; using default.", value, key
            )
            return fallback

    return RuntimeSettings(
        alert_webhook_url=raw.get("alert_webhook_url") or None,
        alert_threshold_count=_int_field("alert_threshold_count", defaults.alert_threshold_count),
        alert_threshold_seconds=_int_field("alert_threshold_seconds", defaults.alert_threshold_seconds),
        archive_days=_int_field("archive_days", defaults.archive_days),
        failsafe_mode=raw.get("failsafe_mode", defaults.failsafe_mode),
    )
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: build, wire, and tear down shared resources.

    Startup order: database schema, Redis (optional — degrades to ``None``
    on failure), HTTP clients, runtime settings, then the background
    services. Teardown reverses creation order in the ``finally`` block.
    """
    settings = get_settings()
    init_db(settings)
    await ensure_schema_compatibility()
    session_factory = get_session_factory()
    # Redis is optional: if the ping fails we continue with redis=None and
    # every downstream consumer must handle the degraded mode.
    redis: Redis | None = redis_from_url(
        settings.redis_addr,
        password=settings.redis_password or None,
        encoding="utf-8",
        decode_responses=True,
    )
    try:
        await redis.ping()
    except Exception:
        logger.warning("Redis is unavailable at startup; continuing in degraded mode.")
        try:
            await redis.aclose()
        except Exception:
            pass
        redis = None
    # Very long read/write timeouts (600s) — presumably to accommodate
    # slow/streaming upstream responses; confirm against app.proxy.handler.
    downstream_client = httpx.AsyncClient(
        timeout=httpx.Timeout(connect=10.0, read=600.0, write=600.0, pool=10.0),
        limits=httpx.Limits(
            max_connections=settings.downstream_max_connections,
            max_keepalive_connections=settings.downstream_max_keepalive_connections,
        ),
        follow_redirects=False,
    )
    webhook_client = httpx.AsyncClient(timeout=httpx.Timeout(settings.webhook_timeout_seconds))
    runtime_settings = await load_runtime_settings(redis, settings)
    # Shared state read by request handlers and the services below.
    app.state.settings = settings
    app.state.redis = redis
    app.state.session_factory = session_factory
    app.state.downstream_client = downstream_client
    app.state.webhook_client = webhook_client
    app.state.runtime_settings = runtime_settings
    app.state.runtime_settings_lock = asyncio.Lock()
    # Services read runtime settings through a getter so later updates to
    # app.state.runtime_settings take effect without re-wiring.
    binding_service = BindingService(
        settings=settings,
        session_factory=session_factory,
        redis=redis,
        runtime_settings_getter=lambda: app.state.runtime_settings,
    )
    alert_service = AlertService(
        settings=settings,
        session_factory=session_factory,
        redis=redis,
        http_client=webhook_client,
        runtime_settings_getter=lambda: app.state.runtime_settings,
    )
    archive_service = ArchiveService(
        settings=settings,
        engine=get_engine(),
        session_factory=session_factory,
        binding_service=binding_service,
        runtime_settings_getter=lambda: app.state.runtime_settings,
    )
    app.state.binding_service = binding_service
    app.state.alert_service = alert_service
    app.state.archive_service = archive_service
    await binding_service.start()
    await archive_service.start()
    logger.info("Application started.")
    try:
        yield
    finally:
        # Tear down in reverse order of creation.
        await archive_service.stop()
        await binding_service.stop()
        await downstream_client.aclose()
        await webhook_client.aclose()
        if redis is not None:
            await redis.aclose()
        await close_db()
        logger.info("Application stopped.")
# Application instance; `lifespan` wires up the DB, Redis, HTTP clients
# and background services around the serving period.
app = FastAPI(title="Key-IP Sentinel", lifespan=lifespan)

# Register the API routers (order preserved).
for _api_router in (
    auth.router,
    dashboard.router,
    bindings.router,
    logs.router,
    settings_api.router,
):
    app.include_router(_api_router)
@app.get("/health")
async def health() -> dict[str, str]:
    """Liveness probe: reports OK whenever the app is serving requests."""
    payload: dict[str, str] = {"status": "ok"}
    return payload
# Proxy router is registered last — presumably a catch-all so the more
# specific API routes above match first; confirm in app.proxy.handler.
app.include_router(proxy_router)