feat(core): 初始化 Key-IP Sentinel 服务与部署骨架
- 搭建 FastAPI、Redis、PostgreSQL、Nginx 与 Docker Compose 基础结构 - 实现反向代理、首用绑定、拦截告警、归档任务和管理接口 - 提供 Vue3 管理后台初版,以及 uv/requirements 双依赖配置
This commit is contained in:
49
app/api/auth.py
Normal file
49
app/api/auth.py
Normal file
@@ -0,0 +1,49 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||
from redis.asyncio import Redis
|
||||
|
||||
from app.config import Settings
|
||||
from app.core.ip_utils import extract_client_ip
|
||||
from app.core.security import (
|
||||
clear_login_failures,
|
||||
create_admin_jwt,
|
||||
ensure_login_allowed,
|
||||
register_login_failure,
|
||||
verify_admin_password,
|
||||
)
|
||||
from app.dependencies import get_redis, get_settings
|
||||
from app.schemas.auth import LoginRequest, TokenResponse
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/admin/api", tags=["auth"])
|
||||
|
||||
|
||||
@router.post("/login", response_model=TokenResponse)
async def login(
    payload: LoginRequest,
    request: Request,
    settings: Settings = Depends(get_settings),
    redis: Redis | None = Depends(get_redis),
) -> TokenResponse:
    """Authenticate the admin password and issue a JWT.

    Redis is required for the brute-force lockout bookkeeping, so the
    endpoint answers 503 while it is offline, 401 on a bad password, and
    (via ensure_login_allowed) 429 once the caller is locked out.
    """
    if redis is None:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Login service is unavailable because Redis is offline.",
        )

    caller_ip = extract_client_ip(request, settings)
    # Reject early when this IP is currently locked out.
    await ensure_login_allowed(redis, caller_ip, settings)

    password_ok = verify_admin_password(payload.password, settings)
    if not password_ok:
        # Count the failure toward the lockout window before rejecting.
        await register_login_failure(redis, caller_ip, settings)
        logger.warning("Admin login failed.", extra={"client_ip": caller_ip})
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid admin password.")

    # Success: reset the failure counter and mint a fresh token.
    await clear_login_failures(redis, caller_ip)
    jwt_token, ttl_seconds = create_admin_jwt(settings)
    logger.info("Admin login succeeded.", extra={"client_ip": caller_ip})
    return TokenResponse(access_token=jwt_token, expires_in=ttl_seconds)
|
||||
153
app/api/bindings.py
Normal file
153
app/api/bindings.py
Normal file
@@ -0,0 +1,153 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Request, status
|
||||
from sqlalchemy import String, cast, func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config import Settings
|
||||
from app.core.ip_utils import extract_client_ip
|
||||
from app.dependencies import get_binding_service, get_db_session, get_settings, require_admin
|
||||
from app.models.token_binding import STATUS_ACTIVE, STATUS_BANNED, TokenBinding
|
||||
from app.schemas.binding import (
|
||||
BindingActionRequest,
|
||||
BindingIPUpdateRequest,
|
||||
BindingItem,
|
||||
BindingListResponse,
|
||||
)
|
||||
from app.services.binding_service import BindingService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/admin/api/bindings", tags=["bindings"], dependencies=[Depends(require_admin)])
|
||||
|
||||
|
||||
def to_binding_item(binding: TokenBinding, binding_service: BindingService) -> BindingItem:
    """Convert an ORM binding row into its API schema representation."""
    fields = {
        "id": binding.id,
        "token_display": binding.token_display,
        "bound_ip": str(binding.bound_ip),
        "status": binding.status,
        "status_label": binding_service.status_label(binding.status),
        "first_used_at": binding.first_used_at,
        "last_used_at": binding.last_used_at,
        "created_at": binding.created_at,
    }
    return BindingItem(**fields)
|
||||
|
||||
|
||||
async def get_binding_or_404(session: AsyncSession, binding_id: int) -> TokenBinding:
    """Load a binding by primary key, aborting with HTTP 404 if absent."""
    record = await session.get(TokenBinding, binding_id)
    if record is not None:
        return record
    raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Binding was not found.")
|
||||
|
||||
|
||||
def log_admin_action(request: Request, settings: Settings, action: str, binding_id: int) -> None:
    """Emit a structured audit log entry for an admin binding mutation."""
    audit_context = {
        "client_ip": extract_client_ip(request, settings),
        "action": action,
        "binding_id": binding_id,
    }
    logger.info("Admin binding action.", extra=audit_context)
|
||||
|
||||
|
||||
@router.get("", response_model=BindingListResponse)
async def list_bindings(
    page: int = Query(default=1, ge=1),
    page_size: int = Query(default=20, ge=1, le=200),
    token_suffix: str | None = Query(default=None),
    ip: str | None = Query(default=None),
    status_filter: int | None = Query(default=None, alias="status"),
    session: AsyncSession = Depends(get_db_session),
    binding_service: BindingService = Depends(get_binding_service),
) -> BindingListResponse:
    """Paginated listing of token bindings with optional filters.

    Filters: substring match on the masked token display, substring match
    on the bound IP (cast to text), and exact status. Results are ordered
    by most recent use.
    """
    statement = select(TokenBinding)
    if token_suffix:
        statement = statement.where(TokenBinding.token_display.ilike(f"%{token_suffix}%"))
    if ip:
        # bound_ip is a non-text column; cast so ILIKE substring search works.
        statement = statement.where(cast(TokenBinding.bound_ip, String).ilike(f"%{ip}%"))
    if status_filter in {STATUS_ACTIVE, STATUS_BANNED}:
        statement = statement.where(TokenBinding.status == status_filter)

    # Count over the filtered statement (as a subquery) so the total matches
    # the filters, independent of pagination.
    total_result = await session.execute(select(func.count()).select_from(statement.subquery()))
    total = int(total_result.scalar_one())
    bindings = (
        await session.scalars(
            statement.order_by(TokenBinding.last_used_at.desc()).offset((page - 1) * page_size).limit(page_size)
        )
    ).all()

    return BindingListResponse(
        items=[to_binding_item(item, binding_service) for item in bindings],
        total=total,
        page=page,
        page_size=page_size,
    )
|
||||
|
||||
|
||||
@router.post("/unbind")
async def unbind_token(
    payload: BindingActionRequest,
    request: Request,
    settings: Settings = Depends(get_settings),
    session: AsyncSession = Depends(get_db_session),
    binding_service: BindingService = Depends(get_binding_service),
):
    """Delete a binding entirely so the token re-binds on its next use."""
    record = await get_binding_or_404(session, payload.id)
    # Capture the hash before the ORM object is deleted.
    hashed_token = record.token_hash
    await session.delete(record)
    await session.commit()
    # Drop the cache entry only after the delete is durable in the DB.
    await binding_service.invalidate_binding_cache(hashed_token)
    log_admin_action(request, settings, "unbind", payload.id)
    return {"success": True}
|
||||
|
||||
|
||||
@router.put("/ip")
async def update_bound_ip(
    payload: BindingIPUpdateRequest,
    request: Request,
    settings: Settings = Depends(get_settings),
    session: AsyncSession = Depends(get_db_session),
    binding_service: BindingService = Depends(get_binding_service),
):
    """Rebind a token to a new IP chosen by the admin."""
    record = await get_binding_or_404(session, payload.id)
    record.bound_ip = payload.bound_ip
    await session.commit()
    # Push the new IP into the Redis cache after the DB write commits.
    await binding_service.sync_binding_cache(record.token_hash, str(record.bound_ip), record.status)
    log_admin_action(request, settings, "update_ip", payload.id)
    return {"success": True}
|
||||
|
||||
|
||||
@router.post("/ban")
async def ban_token(
    payload: BindingActionRequest,
    request: Request,
    settings: Settings = Depends(get_settings),
    session: AsyncSession = Depends(get_db_session),
    binding_service: BindingService = Depends(get_binding_service),
):
    """Mark a binding as banned so the proxy rejects its token."""
    record = await get_binding_or_404(session, payload.id)
    record.status = STATUS_BANNED
    await session.commit()
    # Propagate the new status to the cache after the commit succeeds.
    await binding_service.sync_binding_cache(record.token_hash, str(record.bound_ip), record.status)
    log_admin_action(request, settings, "ban", payload.id)
    return {"success": True}
|
||||
|
||||
|
||||
@router.post("/unban")
async def unban_token(
    payload: BindingActionRequest,
    request: Request,
    settings: Settings = Depends(get_settings),
    session: AsyncSession = Depends(get_db_session),
    binding_service: BindingService = Depends(get_binding_service),
):
    """Restore a banned binding back to the active status."""
    record = await get_binding_or_404(session, payload.id)
    record.status = STATUS_ACTIVE
    await session.commit()
    # Propagate the restored status to the cache after the commit succeeds.
    await binding_service.sync_binding_cache(record.token_hash, str(record.bound_ip), record.status)
    log_admin_action(request, settings, "unban", payload.id)
    return {"success": True}
|
||||
109
app/api/dashboard.py
Normal file
109
app/api/dashboard.py
Normal file
@@ -0,0 +1,109 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime, time, timedelta
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.dependencies import get_binding_service, get_db_session, require_admin
|
||||
from app.models.intercept_log import InterceptLog
|
||||
from app.models.token_binding import STATUS_ACTIVE, STATUS_BANNED, TokenBinding
|
||||
from app.schemas.log import InterceptLogItem
|
||||
from app.services.binding_service import BindingService
|
||||
|
||||
router = APIRouter(prefix="/admin/api", tags=["dashboard"], dependencies=[Depends(require_admin)])
|
||||
|
||||
|
||||
class MetricSummary(BaseModel):
    """Request counters for a single day on the dashboard."""

    total: int
    allowed: int
    intercepted: int
|
||||
|
||||
|
||||
class BindingSummary(BaseModel):
    """Counts of token bindings grouped by status."""

    active: int
    banned: int
|
||||
|
||||
|
||||
class TrendPoint(BaseModel):
    """One day's request totals within the 7-day trend series."""

    # ISO-formatted day (see build_trend, which feeds isoformat() strings).
    date: str
    total: int
    allowed: int
    intercepted: int
|
||||
|
||||
|
||||
class DashboardResponse(BaseModel):
    """Aggregate payload returned by the admin dashboard endpoint."""

    today: MetricSummary
    bindings: BindingSummary
    trend: list[TrendPoint]
    recent_intercepts: list[InterceptLogItem]
|
||||
|
||||
|
||||
async def build_trend(
    session: AsyncSession,
    binding_service: BindingService,
) -> list[TrendPoint]:
    """Build the 7-day traffic trend, reconciling cached metrics with the DB.

    Cached per-day counters come from the binding service; intercept counts
    are cross-checked against the intercept_logs table and the larger value
    wins, so the trend never under-reports intercepts that reached the DB.
    """
    # NOTE(review): assumes each series item is a mapping with "date",
    # "allowed", "intercepted" and "total" keys -- confirm against
    # BindingService.get_metrics_window.
    series = await binding_service.get_metrics_window(days=7)
    # Midnight UTC six days ago: the start of the trailing 7-day window.
    start_day = datetime.combine(datetime.now(UTC).date() - timedelta(days=6), time.min, tzinfo=UTC)

    intercept_counts_result = await session.execute(
        select(func.date(InterceptLog.intercepted_at), func.count())
        .where(InterceptLog.intercepted_at >= start_day)
        .group_by(func.date(InterceptLog.intercepted_at))
    )
    # Map ISO day -> number of intercept rows persisted for that day.
    db_intercept_counts = {
        row[0].isoformat(): int(row[1])
        for row in intercept_counts_result.all()
    }

    trend: list[TrendPoint] = []
    for item in series:
        day = str(item["date"])
        allowed = int(item["allowed"])
        # Prefer whichever source reports more intercepts (cache may lag DB
        # or vice versa); keep total consistent with the parts.
        intercepted = max(int(item["intercepted"]), db_intercept_counts.get(day, 0))
        total = max(int(item["total"]), allowed + intercepted)
        trend.append(TrendPoint(date=day, total=total, allowed=allowed, intercepted=intercepted))
    return trend
|
||||
|
||||
|
||||
async def build_recent_intercepts(session: AsyncSession) -> list[InterceptLogItem]:
    """Fetch the ten most recent intercept rows as API schema items."""
    query = select(InterceptLog).order_by(InterceptLog.intercepted_at.desc()).limit(10)
    rows = (await session.scalars(query)).all()

    items: list[InterceptLogItem] = []
    for row in rows:
        items.append(
            InterceptLogItem(
                id=row.id,
                token_display=row.token_display,
                bound_ip=str(row.bound_ip),
                attempt_ip=str(row.attempt_ip),
                alerted=row.alerted,
                intercepted_at=row.intercepted_at,
            )
        )
    return items
|
||||
|
||||
|
||||
@router.get("/dashboard", response_model=DashboardResponse)
async def get_dashboard(
    session: AsyncSession = Depends(get_db_session),
    binding_service: BindingService = Depends(get_binding_service),
) -> DashboardResponse:
    """Assemble the dashboard: 7-day trend, binding counts, recent intercepts."""
    trend = await build_trend(session, binding_service)

    active_count = await session.scalar(
        select(func.count()).select_from(TokenBinding).where(TokenBinding.status == STATUS_ACTIVE)
    )
    banned_count = await session.scalar(
        select(func.count()).select_from(TokenBinding).where(TokenBinding.status == STATUS_BANNED)
    )
    recent_intercepts = await build_recent_intercepts(session)

    # The last trend entry is today; fall back to an all-zero point when
    # the trend is empty so "today" metrics are always present.
    today = trend[-1] if trend else TrendPoint(date=datetime.now(UTC).date().isoformat(), total=0, allowed=0, intercepted=0)
    return DashboardResponse(
        today=MetricSummary(total=today.total, allowed=today.allowed, intercepted=today.intercepted),
        bindings=BindingSummary(active=int(active_count or 0), banned=int(banned_count or 0)),
        trend=trend,
        recent_intercepts=recent_intercepts,
    )
|
||||
107
app/api/logs.py
Normal file
107
app/api/logs.py
Normal file
@@ -0,0 +1,107 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import csv
|
||||
import io
|
||||
from datetime import datetime
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy import String, cast, func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.dependencies import get_db_session, require_admin
|
||||
from app.models.intercept_log import InterceptLog
|
||||
from app.schemas.log import InterceptLogItem, LogListResponse
|
||||
|
||||
router = APIRouter(prefix="/admin/api/logs", tags=["logs"], dependencies=[Depends(require_admin)])
|
||||
|
||||
|
||||
def apply_log_filters(
    statement,
    token: str | None,
    attempt_ip: str | None,
    start_time: datetime | None,
    end_time: datetime | None,
):
    """Append a WHERE clause for each filter that was actually supplied."""
    conditions = []
    if token:
        conditions.append(InterceptLog.token_display.ilike(f"%{token}%"))
    if attempt_ip:
        # attempt_ip is a non-text column; cast so ILIKE substring search works.
        conditions.append(cast(InterceptLog.attempt_ip, String).ilike(f"%{attempt_ip}%"))
    if start_time:
        conditions.append(InterceptLog.intercepted_at >= start_time)
    if end_time:
        conditions.append(InterceptLog.intercepted_at <= end_time)

    for condition in conditions:
        statement = statement.where(condition)
    return statement
|
||||
|
||||
|
||||
@router.get("", response_model=LogListResponse)
async def list_logs(
    page: int = Query(default=1, ge=1),
    page_size: int = Query(default=20, ge=1, le=200),
    token: str | None = Query(default=None),
    attempt_ip: str | None = Query(default=None),
    start_time: datetime | None = Query(default=None),
    end_time: datetime | None = Query(default=None),
    session: AsyncSession = Depends(get_db_session),
) -> LogListResponse:
    """Paginated intercept-log listing, newest first, with optional filters."""
    statement = apply_log_filters(select(InterceptLog), token, attempt_ip, start_time, end_time)
    # Count over the filtered statement (as a subquery) so the total matches
    # the filters, independent of pagination.
    total_result = await session.execute(select(func.count()).select_from(statement.subquery()))
    total = int(total_result.scalar_one())
    logs = (
        await session.scalars(
            statement.order_by(InterceptLog.intercepted_at.desc()).offset((page - 1) * page_size).limit(page_size)
        )
    ).all()

    return LogListResponse(
        items=[
            InterceptLogItem(
                id=item.id,
                token_display=item.token_display,
                bound_ip=str(item.bound_ip),
                attempt_ip=str(item.attempt_ip),
                alerted=item.alerted,
                intercepted_at=item.intercepted_at,
            )
            for item in logs
        ],
        total=total,
        page=page,
        page_size=page_size,
    )
|
||||
|
||||
|
||||
@router.get("/export")
async def export_logs(
    token: str | None = Query(default=None),
    attempt_ip: str | None = Query(default=None),
    start_time: datetime | None = Query(default=None),
    end_time: datetime | None = Query(default=None),
    session: AsyncSession = Depends(get_db_session),
):
    """Export all intercept logs matching the filters as a CSV attachment.

    Uses the same filters as the listing endpoint but without pagination;
    rows are ordered newest first.
    """
    # Local import: the module-level datetime import does not bring in UTC.
    from datetime import UTC

    statement = apply_log_filters(select(InterceptLog), token, attempt_ip, start_time, end_time).order_by(
        InterceptLog.intercepted_at.desc()
    )
    logs = (await session.scalars(statement)).all()

    buffer = io.StringIO()
    writer = csv.writer(buffer)
    writer.writerow(["id", "token_display", "bound_ip", "attempt_ip", "alerted", "intercepted_at"])
    for item in logs:
        writer.writerow(
            [
                item.id,
                item.token_display,
                str(item.bound_ip),
                str(item.attempt_ip),
                item.alerted,
                item.intercepted_at.isoformat(),
            ]
        )

    # FIX: the Content-Disposition header previously contained a literal
    # placeholder instead of the generated filename, so downloads were not
    # saved under the timestamped name. Also use timezone-aware now(UTC)
    # instead of the deprecated datetime.utcnow().
    filename = f"sentinel-logs-{datetime.now(UTC).strftime('%Y%m%d%H%M%S')}.csv"
    return StreamingResponse(
        iter([buffer.getvalue()]),
        media_type="text/csv",
        headers={"Content-Disposition": f'attachment; filename="{filename}"'},
    )
|
||||
77
app/api/settings.py
Normal file
77
app/api/settings.py
Normal file
@@ -0,0 +1,77 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Literal
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||
from pydantic import BaseModel, Field
|
||||
from redis.asyncio import Redis
|
||||
|
||||
from app.config import RUNTIME_SETTINGS_REDIS_KEY, RuntimeSettings, Settings
|
||||
from app.core.ip_utils import extract_client_ip
|
||||
from app.dependencies import get_redis, get_runtime_settings, get_settings, require_admin
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/admin/api/settings", tags=["settings"], dependencies=[Depends(require_admin)])
|
||||
|
||||
|
||||
class SettingsResponse(BaseModel):
    """Runtime-tunable settings as exposed over the admin settings API."""

    alert_webhook_url: str | None = None
    alert_threshold_count: int = Field(ge=1)
    alert_threshold_seconds: int = Field(ge=1)
    archive_days: int = Field(ge=1)
    failsafe_mode: Literal["open", "closed"]
|
||||
|
||||
|
||||
class SettingsUpdateRequest(SettingsResponse):
    """Update payload; identical shape to the response schema."""

    pass
|
||||
|
||||
|
||||
def serialize_runtime_settings(runtime_settings: RuntimeSettings) -> dict[str, str]:
    """Flatten runtime settings into the string-only mapping a Redis hash needs."""
    webhook = runtime_settings.alert_webhook_url or ""
    return {
        "alert_webhook_url": webhook,
        "alert_threshold_count": str(runtime_settings.alert_threshold_count),
        "alert_threshold_seconds": str(runtime_settings.alert_threshold_seconds),
        "archive_days": str(runtime_settings.archive_days),
        "failsafe_mode": runtime_settings.failsafe_mode,
    }
|
||||
|
||||
|
||||
@router.get("", response_model=SettingsResponse)
async def get_runtime_config(
    runtime_settings: RuntimeSettings = Depends(get_runtime_settings),
) -> SettingsResponse:
    """Return the currently active runtime-tunable settings."""
    snapshot = runtime_settings.model_dump()
    return SettingsResponse(**snapshot)
|
||||
|
||||
|
||||
@router.put("", response_model=SettingsResponse)
async def update_runtime_config(
    payload: SettingsUpdateRequest,
    request: Request,
    settings: Settings = Depends(get_settings),
    redis: Redis | None = Depends(get_redis),
):
    """Persist new runtime settings to Redis, then swap the in-memory copy.

    Redis is the durable store, so the endpoint refuses (503) when it is
    offline or the write fails; the in-memory copy on app.state is only
    replaced after a successful persist, under the runtime settings lock.
    """
    if redis is None:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Settings persistence is unavailable because Redis is offline.",
        )

    # Re-validate through RuntimeSettings before persisting.
    updated = RuntimeSettings(**payload.model_dump())
    try:
        await redis.hset(RUNTIME_SETTINGS_REDIS_KEY, mapping=serialize_runtime_settings(updated))
    except Exception as exc:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Failed to persist runtime settings.",
        ) from exc

    # Only mutate the live copy once the write above succeeded.
    async with request.app.state.runtime_settings_lock:
        request.app.state.runtime_settings = updated

    logger.info(
        "Runtime settings updated.",
        extra={"client_ip": extract_client_ip(request, settings)},
    )
    return SettingsResponse(**updated.model_dump())
|
||||
98
app/config.py
Normal file
98
app/config.py
Normal file
@@ -0,0 +1,98 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from functools import cached_property
|
||||
from ipaddress import ip_network
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
RUNTIME_SETTINGS_REDIS_KEY = "sentinel:settings"
|
||||
|
||||
|
||||
class RuntimeSettings(BaseModel):
    """Settings that can be changed at runtime through the admin API.

    Defaults mirror the environment-backed fields on Settings; the live copy
    is persisted in Redis under RUNTIME_SETTINGS_REDIS_KEY.
    """

    alert_webhook_url: str | None = None
    alert_threshold_count: int = Field(default=5, ge=1)
    alert_threshold_seconds: int = Field(default=300, ge=1)
    archive_days: int = Field(default=90, ge=1)
    failsafe_mode: Literal["open", "closed"] = "closed"
|
||||
|
||||
|
||||
class Settings(BaseSettings):
    """Environment-backed service configuration (loaded from .env / env vars)."""

    model_config = SettingsConfigDict(
        env_file=".env",
        env_file_encoding="utf-8",
        extra="ignore",
        case_sensitive=False,
    )

    # Required deployment wiring.
    downstream_url: str = Field(alias="DOWNSTREAM_URL")
    redis_addr: str = Field(alias="REDIS_ADDR")
    redis_password: str = Field(default="", alias="REDIS_PASSWORD")
    pg_dsn: str = Field(alias="PG_DSN")
    # Secrets; minimum lengths are enforced at load time.
    sentinel_hmac_secret: str = Field(alias="SENTINEL_HMAC_SECRET", min_length=32)
    admin_password: str = Field(alias="ADMIN_PASSWORD", min_length=8)
    admin_jwt_secret: str = Field(alias="ADMIN_JWT_SECRET", min_length=16)
    # Comma-separated list in the environment; normalized by split_proxy_ips.
    trusted_proxy_ips: tuple[str, ...] = Field(default_factory=tuple, alias="TRUSTED_PROXY_IPS")
    sentinel_failsafe_mode: Literal["open", "closed"] = Field(
        default="closed",
        alias="SENTINEL_FAILSAFE_MODE",
    )
    app_port: int = Field(default=7000, alias="APP_PORT")
    alert_webhook_url: str | None = Field(default=None, alias="ALERT_WEBHOOK_URL")
    alert_threshold_count: int = Field(default=5, alias="ALERT_THRESHOLD_COUNT", ge=1)
    alert_threshold_seconds: int = Field(default=300, alias="ALERT_THRESHOLD_SECONDS", ge=1)
    archive_days: int = Field(default=90, alias="ARCHIVE_DAYS", ge=1)

    # Fixed tuning knobs (no environment aliases).
    redis_binding_ttl_seconds: int = 604800
    downstream_max_connections: int = 512
    downstream_max_keepalive_connections: int = 128
    last_used_flush_interval_seconds: int = 5
    last_used_queue_size: int = 10000
    login_lockout_threshold: int = 5
    login_lockout_seconds: int = 900
    admin_jwt_expire_hours: int = 8
    archive_job_interval_minutes: int = 60
    archive_batch_size: int = 500
    metrics_ttl_days: int = 30
    webhook_timeout_seconds: int = 5

    @field_validator("downstream_url")
    @classmethod
    def normalize_downstream_url(cls, value: str) -> str:
        # Strip trailing slashes so path joining elsewhere stays predictable.
        return value.rstrip("/")

    @field_validator("trusted_proxy_ips", mode="before")
    @classmethod
    def split_proxy_ips(cls, value: object) -> tuple[str, ...]:
        # Accept None, a comma-separated string, or any iterable of values;
        # always normalize to a tuple of non-empty stripped strings.
        if value is None:
            return tuple()
        if isinstance(value, str):
            parts = [item.strip() for item in value.split(",")]
            return tuple(item for item in parts if item)
        if isinstance(value, (list, tuple, set)):
            return tuple(str(item).strip() for item in value if str(item).strip())
        return (str(value).strip(),)

    @cached_property
    def trusted_proxy_networks(self):
        # Parsed once per Settings instance; hosts are accepted via strict=False.
        return tuple(ip_network(item, strict=False) for item in self.trusted_proxy_ips)

    def build_runtime_settings(self) -> RuntimeSettings:
        """Derive the initial RuntimeSettings from environment values."""
        return RuntimeSettings(
            alert_webhook_url=self.alert_webhook_url or None,
            alert_threshold_count=self.alert_threshold_count,
            alert_threshold_seconds=self.alert_threshold_seconds,
            archive_days=self.archive_days,
            failsafe_mode=self.sentinel_failsafe_mode,
        )
|
||||
|
||||
|
||||
# Lazily constructed module-level singleton; populated on first get_settings().
_settings: Settings | None = None


def get_settings() -> Settings:
    """Return the process-wide Settings instance, constructing it on first use.

    No locking is used, so the first call is expected to happen before any
    concurrent access (e.g. during application startup).
    """
    global _settings
    if _settings is None:
        _settings = Settings()
    return _settings
|
||||
35
app/core/ip_utils.py
Normal file
35
app/core/ip_utils.py
Normal file
@@ -0,0 +1,35 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from ipaddress import ip_address, ip_network
|
||||
|
||||
from fastapi import Request
|
||||
|
||||
from app.config import Settings
|
||||
|
||||
|
||||
def is_ip_in_network(candidate_ip: str, network_value: str) -> bool:
    """True when *candidate_ip* falls inside *network_value* (CIDR or single host)."""
    network = ip_network(network_value, strict=False)
    return ip_address(candidate_ip) in network
|
||||
|
||||
|
||||
def is_trusted_proxy(source_ip: str, settings: Settings) -> bool:
    """True when *source_ip* parses and lies in any configured proxy network."""
    try:
        candidate = ip_address(source_ip)
    except ValueError:
        # Malformed addresses can never be trusted.
        return False
    for network in settings.trusted_proxy_networks:
        if candidate in network:
            return True
    return False
|
||||
|
||||
|
||||
def extract_client_ip(request: Request, settings: Settings) -> str:
    """Best-effort real client IP for a request.

    X-Real-IP is honoured only when the direct peer is a configured trusted
    proxy AND the header holds a syntactically valid address; in every other
    case the socket peer address is returned.
    """
    peer = request.client.host if request.client else "127.0.0.1"
    if not is_trusted_proxy(peer, settings):
        return peer

    forwarded = request.headers.get("x-real-ip")
    if not forwarded:
        return peer

    try:
        ip_address(forwarded)
    except ValueError:
        # Never trust a header value that does not parse as an address.
        return peer
    return forwarded
|
||||
104
app/core/security.py
Normal file
104
app/core/security.py
Normal file
@@ -0,0 +1,104 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import hmac
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
from fastapi import HTTPException, status
|
||||
from jose import JWTError, jwt
|
||||
from redis.asyncio import Redis
|
||||
|
||||
from app.config import Settings
|
||||
|
||||
ALGORITHM = "HS256"
|
||||
|
||||
|
||||
def mask_token(token: str) -> str:
    """Return a redacted display form of *token* safe for logs and the UI."""
    if not token:
        return "unknown"
    # Short tokens keep 2 chars on each side, longer ones keep 4.
    keep = 2 if len(token) <= 8 else 4
    masked = f"{token[:keep]}...{token[-keep:]}"
    # Defensive cap (a no-op for the shapes above) on displayed length.
    return masked[:20]
|
||||
|
||||
|
||||
def hash_token(token: str, secret: str) -> str:
    """HMAC-SHA256 digest of *token* keyed with *secret*, as lowercase hex."""
    mac = hmac.new(secret.encode("utf-8"), token.encode("utf-8"), hashlib.sha256)
    return mac.hexdigest()
|
||||
|
||||
|
||||
def extract_bearer_token(authorization: str | None) -> str | None:
|
||||
if not authorization:
|
||||
return None
|
||||
scheme, _, token = authorization.partition(" ")
|
||||
if scheme.lower() != "bearer" or not token:
|
||||
return None
|
||||
return token.strip()
|
||||
|
||||
|
||||
def verify_admin_password(password: str, settings: Settings) -> bool:
    """Timing-safe comparison of the submitted password with the configured one."""
    expected = settings.admin_password
    return hmac.compare_digest(password, expected)
|
||||
|
||||
|
||||
def create_admin_jwt(settings: Settings) -> tuple[str, int]:
    """Mint a signed admin JWT; returns (token, lifetime in seconds)."""
    ttl_seconds = settings.admin_jwt_expire_hours * 3600
    issued_at = datetime.now(UTC)
    expiry = issued_at + timedelta(seconds=ttl_seconds)
    claims = {
        "sub": "admin",
        "iat": int(issued_at.timestamp()),
        "exp": int(expiry.timestamp()),
    }
    signed = jwt.encode(claims, settings.admin_jwt_secret, algorithm=ALGORITHM)
    return signed, ttl_seconds
|
||||
|
||||
|
||||
def decode_admin_jwt(token: str, settings: Settings) -> dict:
    """Validate an admin JWT and return its claims.

    Raises HTTPException(401) on signature/expiry failure or a wrong subject.
    """
    try:
        claims = jwt.decode(token, settings.admin_jwt_secret, algorithms=[ALGORITHM])
    except JWTError as exc:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid or expired admin token.",
        ) from exc

    if claims.get("sub") != "admin":
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid admin token subject.",
        )
    return claims
|
||||
|
||||
|
||||
def login_failure_key(client_ip: str) -> str:
    """Redis key tracking consecutive failed logins for one client IP."""
    return "sentinel:login:fail:" + client_ip
|
||||
|
||||
|
||||
async def ensure_login_allowed(redis: Redis, client_ip: str, settings: Settings) -> None:
    """Raise 429 when *client_ip* has exceeded the login failure threshold.

    Returns silently when no failures are recorded or the count is below
    the threshold. Redis errors (or a non-integer counter) are mapped to
    503 so callers never see raw backend exceptions.
    """
    try:
        current = await redis.get(login_failure_key(client_ip))
        if current is None:
            return
        if int(current) >= settings.login_lockout_threshold:
            # Surface the remaining lockout window to the caller; a missing
            # TTL (negative) is clamped to 0.
            ttl = await redis.ttl(login_failure_key(client_ip))
            retry_after = max(ttl, 0)
            raise HTTPException(
                status_code=status.HTTP_429_TOO_MANY_REQUESTS,
                detail=f"Too many failed login attempts. Retry after {retry_after} seconds.",
            )
    except HTTPException:
        # Re-raise our own 429 untouched; only backend failures become 503.
        raise
    except Exception as exc:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Login lock service is unavailable.",
        ) from exc
|
||||
|
||||
|
||||
async def register_login_failure(redis: Redis, client_ip: str, settings: Settings) -> None:
    """Atomically bump the failure counter for *client_ip* and refresh its TTL."""
    failure_key = login_failure_key(client_ip)
    # Transactional pipeline: the increment and the expiry refresh are applied together.
    async with redis.pipeline(transaction=True) as pipe:
        pipe.incr(failure_key)
        pipe.expire(failure_key, settings.login_lockout_seconds)
        await pipe.execute()
|
||||
|
||||
|
||||
async def clear_login_failures(redis: Redis, client_ip: str) -> None:
    """Reset the lockout counter for *client_ip* (e.g. after a successful login)."""
    failure_key = login_failure_key(client_ip)
    await redis.delete(failure_key)
|
||||
53
app/dependencies.py
Normal file
53
app/dependencies.py
Normal file
@@ -0,0 +1,53 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import AsyncIterator
|
||||
|
||||
from fastapi import Depends, HTTPException, Request, status
|
||||
from redis.asyncio import Redis
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config import RuntimeSettings, Settings
|
||||
from app.core.security import decode_admin_jwt, extract_bearer_token
|
||||
from app.services.alert_service import AlertService
|
||||
from app.services.archive_service import ArchiveService
|
||||
from app.services.binding_service import BindingService
|
||||
|
||||
|
||||
def get_settings(request: Request) -> Settings:
    """Dependency: the process-wide Settings stored on app.state at startup."""
    app_state = request.app.state
    return app_state.settings
|
||||
|
||||
|
||||
def get_redis(request: Request) -> Redis | None:
    """Dependency: the shared Redis client, or None when Redis is offline."""
    app_state = request.app.state
    return app_state.redis
|
||||
|
||||
|
||||
async def get_db_session(request: Request) -> AsyncIterator[AsyncSession]:
    """Dependency: yield one AsyncSession from the app-wide factory.

    The async-with block guarantees the session is closed when the request
    scope ends.
    """
    factory = request.app.state.session_factory
    async with factory() as session:
        yield session
|
||||
|
||||
|
||||
def get_binding_service(request: Request) -> BindingService:
    """Dependency: the shared BindingService instance from app.state."""
    app_state = request.app.state
    return app_state.binding_service
|
||||
|
||||
|
||||
def get_alert_service(request: Request) -> AlertService:
    """Dependency: the shared AlertService instance from app.state."""
    app_state = request.app.state
    return app_state.alert_service
|
||||
|
||||
|
||||
def get_archive_service(request: Request) -> ArchiveService:
    """Dependency: the shared ArchiveService instance from app.state."""
    app_state = request.app.state
    return app_state.archive_service
|
||||
|
||||
|
||||
def get_runtime_settings(request: Request) -> RuntimeSettings:
    """Dependency: the live RuntimeSettings snapshot from app.state."""
    app_state = request.app.state
    return app_state.runtime_settings
|
||||
|
||||
|
||||
async def require_admin(request: Request, settings: Settings = Depends(get_settings)) -> dict:
    """Dependency guard: validate the admin bearer JWT and return its claims.

    Raises 401 when the Authorization header carries no bearer token;
    decode_admin_jwt raises its own 401 for invalid/expired tokens.
    """
    bearer = extract_bearer_token(request.headers.get("authorization"))
    if bearer is None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Missing admin bearer token.",
        )
    return decode_admin_jwt(bearer, settings)
|
||||
193
app/main.py
Normal file
193
app/main.py
Normal file
@@ -0,0 +1,193 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from contextlib import asynccontextmanager
|
||||
from datetime import UTC, datetime
|
||||
|
||||
import httpx
|
||||
from fastapi import FastAPI
|
||||
from redis.asyncio import Redis
|
||||
from redis.asyncio import from_url as redis_from_url
|
||||
|
||||
from app.api import auth, bindings, dashboard, logs, settings as settings_api
|
||||
from app.config import RUNTIME_SETTINGS_REDIS_KEY, RuntimeSettings, Settings, get_settings
|
||||
from app.models import intercept_log, token_binding # noqa: F401
|
||||
from app.models.db import close_db, get_session_factory, init_db
|
||||
from app.proxy.handler import router as proxy_router
|
||||
from app.services.alert_service import AlertService
|
||||
from app.services.archive_service import ArchiveService
|
||||
from app.services.binding_service import BindingService
|
||||
|
||||
|
||||
class JsonFormatter(logging.Formatter):
|
||||
reserved = {
|
||||
"args",
|
||||
"asctime",
|
||||
"created",
|
||||
"exc_info",
|
||||
"exc_text",
|
||||
"filename",
|
||||
"funcName",
|
||||
"levelname",
|
||||
"levelno",
|
||||
"lineno",
|
||||
"module",
|
||||
"msecs",
|
||||
"message",
|
||||
"msg",
|
||||
"name",
|
||||
"pathname",
|
||||
"process",
|
||||
"processName",
|
||||
"relativeCreated",
|
||||
"stack_info",
|
||||
"thread",
|
||||
"threadName",
|
||||
}
|
||||
|
||||
def format(self, record: logging.LogRecord) -> str:
|
||||
payload = {
|
||||
"timestamp": datetime.now(UTC).isoformat(),
|
||||
"level": record.levelname,
|
||||
"logger": record.name,
|
||||
"message": record.getMessage(),
|
||||
}
|
||||
for key, value in record.__dict__.items():
|
||||
if key in self.reserved or key.startswith("_"):
|
||||
continue
|
||||
payload[key] = value
|
||||
if record.exc_info:
|
||||
payload["exception"] = self.formatException(record.exc_info)
|
||||
return json.dumps(payload, default=str)
|
||||
|
||||
|
||||
def configure_logging() -> None:
    """Install a single JSON stream handler on the root logger at INFO level.

    Clears any handlers installed earlier so log lines are emitted exactly once.
    """
    json_handler = logging.StreamHandler()
    json_handler.setFormatter(JsonFormatter())
    root = logging.getLogger()
    root.handlers.clear()
    root.addHandler(json_handler)
    root.setLevel(logging.INFO)
|
||||
|
||||
|
||||
# Configure structured JSON logging at import time so every module logger
# created afterwards inherits the handler.
configure_logging()
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def load_runtime_settings(redis: Redis | None, settings: Settings) -> RuntimeSettings:
    """Build runtime settings from the environment, overlaying Redis overrides.

    Falls back to the environment-derived defaults when Redis is unavailable,
    unreachable, or holds no stored values.
    """
    defaults = settings.build_runtime_settings()
    if redis is None:
        return defaults
    try:
        stored = await redis.hgetall(RUNTIME_SETTINGS_REDIS_KEY)
    except Exception:
        logger.warning("Failed to load runtime settings from Redis; using environment defaults.")
        return defaults
    if not stored:
        return defaults
    return RuntimeSettings(
        alert_webhook_url=stored.get("alert_webhook_url") or None,
        alert_threshold_count=int(stored.get("alert_threshold_count", defaults.alert_threshold_count)),
        alert_threshold_seconds=int(stored.get("alert_threshold_seconds", defaults.alert_threshold_seconds)),
        archive_days=int(stored.get("archive_days", defaults.archive_days)),
        failsafe_mode=stored.get("failsafe_mode", defaults.failsafe_mode),
    )
|
||||
|
||||
|
||||
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: build shared clients/services on startup and tear
    them down in reverse order on shutdown."""
    settings = get_settings()
    init_db(settings)
    session_factory = get_session_factory()

    # Redis is optional. If the startup ping fails we close the client and run
    # in degraded mode: services receive redis=None and fall back to PostgreSQL.
    redis: Redis | None = redis_from_url(
        settings.redis_addr,
        password=settings.redis_password or None,
        encoding="utf-8",
        decode_responses=True,
    )
    try:
        await redis.ping()
    except Exception:
        logger.warning("Redis is unavailable at startup; continuing in degraded mode.")
        try:
            await redis.aclose()
        except Exception:
            pass
        redis = None

    # Generous read/write timeouts (600s) to accommodate long streamed
    # downstream responses relayed by the proxy handler.
    downstream_client = httpx.AsyncClient(
        timeout=httpx.Timeout(connect=10.0, read=600.0, write=600.0, pool=10.0),
        limits=httpx.Limits(
            max_connections=settings.downstream_max_connections,
            max_keepalive_connections=settings.downstream_max_keepalive_connections,
        ),
        follow_redirects=False,
    )
    webhook_client = httpx.AsyncClient(timeout=httpx.Timeout(settings.webhook_timeout_seconds))

    runtime_settings = await load_runtime_settings(redis, settings)
    app.state.settings = settings
    app.state.redis = redis
    app.state.session_factory = session_factory
    app.state.downstream_client = downstream_client
    app.state.webhook_client = webhook_client
    app.state.runtime_settings = runtime_settings
    app.state.runtime_settings_lock = asyncio.Lock()

    # Services read runtime settings through a getter lambda so admin updates
    # to app.state.runtime_settings take effect without a restart.
    binding_service = BindingService(
        settings=settings,
        session_factory=session_factory,
        redis=redis,
        runtime_settings_getter=lambda: app.state.runtime_settings,
    )
    alert_service = AlertService(
        settings=settings,
        session_factory=session_factory,
        redis=redis,
        http_client=webhook_client,
        runtime_settings_getter=lambda: app.state.runtime_settings,
    )
    archive_service = ArchiveService(
        settings=settings,
        session_factory=session_factory,
        binding_service=binding_service,
        runtime_settings_getter=lambda: app.state.runtime_settings,
    )
    app.state.binding_service = binding_service
    app.state.alert_service = alert_service
    app.state.archive_service = archive_service

    await binding_service.start()
    await archive_service.start()
    logger.info("Application started.")
    try:
        yield
    finally:
        # Shut down in reverse dependency order: background jobs first, then
        # HTTP clients, then Redis, and finally the database engine.
        await archive_service.stop()
        await binding_service.stop()
        await downstream_client.aclose()
        await webhook_client.aclose()
        if redis is not None:
            await redis.aclose()
        await close_db()
        logger.info("Application stopped.")
|
||||
|
||||
|
||||
app = FastAPI(title="Key-IP Sentinel", lifespan=lifespan)

# Admin/API routers are registered before the proxy router so their fixed
# paths take precedence over the catch-all /{path:path} proxy route.
app.include_router(auth.router)
app.include_router(dashboard.router)
app.include_router(bindings.router)
app.include_router(logs.router)
app.include_router(settings_api.router)


@app.get("/health")
async def health() -> dict[str, str]:
    """Liveness probe: always returns a static OK payload."""
    return {"status": "ok"}


# The reverse proxy is a catch-all route and MUST be registered last.
app.include_router(proxy_router)
|
||||
48
app/models/db.py
Normal file
48
app/models/db.py
Normal file
@@ -0,0 +1,48 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker, create_async_engine
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
|
||||
from app.config import Settings
|
||||
|
||||
|
||||
class Base(DeclarativeBase):
    """Declarative base class shared by all ORM models in this service."""

    pass
|
||||
|
||||
|
||||
# Module-level singletons: created by init_db() and torn down by close_db().
_engine: AsyncEngine | None = None
_session_factory: async_sessionmaker[AsyncSession] | None = None
|
||||
|
||||
|
||||
def init_db(settings: Settings) -> None:
    """Create the async engine and session factory once; later calls no-op.

    pool_pre_ping guards against stale connections after DB restarts.
    """
    global _engine, _session_factory
    already_initialized = _engine is not None and _session_factory is not None
    if already_initialized:
        return
    engine = create_async_engine(
        settings.pg_dsn,
        pool_pre_ping=True,
        pool_size=20,
        max_overflow=40,
    )
    _engine = engine
    _session_factory = async_sessionmaker(engine, expire_on_commit=False)
|
||||
|
||||
|
||||
def get_engine() -> AsyncEngine:
    """Return the global engine; raise if init_db() has not run yet."""
    engine = _engine
    if engine is None:
        raise RuntimeError("Database engine has not been initialized.")
    return engine
|
||||
|
||||
|
||||
def get_session_factory() -> async_sessionmaker[AsyncSession]:
    """Return the global session factory; raise if init_db() has not run yet."""
    factory = _session_factory
    if factory is None:
        raise RuntimeError("Database session factory has not been initialized.")
    return factory
|
||||
|
||||
|
||||
async def close_db() -> None:
    """Dispose the engine (if any) and reset the module-level singletons."""
    global _engine, _session_factory
    engine, _engine = _engine, None
    _session_factory = None
    if engine is not None:
        await engine.dispose()
|
||||
29
app/models/intercept_log.py
Normal file
29
app/models/intercept_log.py
Normal file
@@ -0,0 +1,29 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import Boolean, DateTime, Index, String, func, text
|
||||
from sqlalchemy.dialects.postgresql import CIDR, INET
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from app.models.db import Base
|
||||
|
||||
|
||||
class InterceptLog(Base):
    """Audit record for a rejected proxy request (token/IP binding mismatch)."""

    __tablename__ = "intercept_logs"
    __table_args__ = (
        Index("idx_intercept_logs_hash", "token_hash"),
        # Descending time index supports "newest first" admin listings.
        Index("idx_intercept_logs_time", text("intercepted_at DESC")),
    )

    id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
    # Keyed hash of the raw token (see app.core.security.hash_token); the raw
    # token is never persisted.
    token_hash: Mapped[str] = mapped_column(String(64), nullable=False)
    # Masked token preview safe for display in the admin UI.
    token_display: Mapped[str] = mapped_column(String(20), nullable=False)
    # CIDR the token was bound to at the time of the intercept.
    bound_ip: Mapped[str] = mapped_column(CIDR, nullable=False)
    # Client IP that attempted to use the token.
    attempt_ip: Mapped[str] = mapped_column(INET, nullable=False)
    # Set to TRUE once this record has been covered by a webhook alert.
    alerted: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False, server_default=text("FALSE"))
    intercepted_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        nullable=False,
        server_default=func.now(),
    )
|
||||
46
app/models/token_binding.py
Normal file
46
app/models/token_binding.py
Normal file
@@ -0,0 +1,46 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import DateTime, Index, SmallInteger, String, func, text
|
||||
from sqlalchemy.dialects.postgresql import CIDR
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from app.models.db import Base
|
||||
|
||||
# Binding lifecycle states persisted in TokenBinding.status.
STATUS_ACTIVE = 1
STATUS_BANNED = 2


class TokenBinding(Base):
    """First-use binding of an API token (stored by hash) to a client CIDR."""

    __tablename__ = "token_bindings"
    __table_args__ = (
        Index("idx_token_bindings_hash", "token_hash"),
        # GiST + inet_ops enables efficient CIDR containment lookups.
        Index("idx_token_bindings_ip", "bound_ip", postgresql_using="gist", postgresql_ops={"bound_ip": "inet_ops"}),
    )

    id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
    # Keyed hash of the raw token (see app.core.security.hash_token); unique
    # so each token has at most one binding row.
    token_hash: Mapped[str] = mapped_column(String(64), unique=True, nullable=False)
    # Masked token preview safe for display in the admin UI.
    token_display: Mapped[str] = mapped_column(String(20), nullable=False)
    bound_ip: Mapped[str] = mapped_column(CIDR, nullable=False)
    status: Mapped[int] = mapped_column(
        SmallInteger,
        nullable=False,
        default=STATUS_ACTIVE,
        server_default=text("1"),
    )
    first_used_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        nullable=False,
        server_default=func.now(),
    )
    # Refreshed in batches by BindingService.flush_last_used_updates(); drives
    # the inactivity cutoff used by ArchiveService.
    last_used_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        nullable=False,
        server_default=func.now(),
    )
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        nullable=False,
        server_default=func.now(),
    )
|
||||
111
app/proxy/handler.py
Normal file
111
app/proxy/handler.py
Normal file
@@ -0,0 +1,111 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from urllib.parse import urlsplit
|
||||
|
||||
import httpx
|
||||
from fastapi import APIRouter, Depends, Request
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
|
||||
from app.config import Settings
|
||||
from app.core.ip_utils import extract_client_ip
|
||||
from app.core.security import extract_bearer_token
|
||||
from app.dependencies import get_alert_service, get_binding_service, get_settings
|
||||
from app.services.alert_service import AlertService
|
||||
from app.services.binding_service import BindingService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter()
|
||||
CONTENT_LENGTH_HEADER = "content-length"
|
||||
|
||||
|
||||
def build_upstream_headers(request: Request, downstream_url: str) -> list[tuple[str, str]]:
    """Copy the client's headers, replacing Host with the downstream host."""
    target_host = urlsplit(downstream_url).netloc
    forwarded = [
        (name, value)
        for name, value in request.headers.items()
        if name.lower() != "host"
    ]
    forwarded.append(("host", target_host))
    return forwarded
|
||||
|
||||
|
||||
def build_upstream_url(settings: Settings, request: Request) -> str:
    """Join the configured downstream base URL with the incoming request path."""
    base = settings.downstream_url
    path = request.url.path
    return f"{base}{path}"
|
||||
|
||||
|
||||
def apply_downstream_headers(response: StreamingResponse, upstream_response: httpx.Response) -> None:
    """Mirror upstream headers onto the streaming response.

    Content-Length is skipped: the body is re-streamed, so the framework must
    determine the framing itself.
    """
    for name, value in upstream_response.headers.multi_items():
        if name.lower() != CONTENT_LENGTH_HEADER:
            response.headers.append(name, value)
|
||||
|
||||
|
||||
@router.api_route("/", methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS", "HEAD"], include_in_schema=False)
@router.api_route(
    "/{path:path}",
    methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS", "HEAD"],
    include_in_schema=False,
)
async def reverse_proxy(
    request: Request,
    path: str = "",
    settings: Settings = Depends(get_settings),
    binding_service: BindingService = Depends(get_binding_service),
    alert_service: AlertService = Depends(get_alert_service),
):
    """Catch-all reverse proxy: enforce the token/IP binding, then stream the
    request to the downstream service and stream its response back."""
    client_ip = extract_client_ip(request, settings)
    token = extract_bearer_token(request.headers.get("authorization"))

    if token:
        binding_result = await binding_service.evaluate_token_binding(token, client_ip)
        if binding_result.allowed:
            await binding_service.increment_request_metric("allowed")
        else:
            # NOTE(review): passes None when should_alert is False (e.g. a
            # degraded-mode rejection) — confirm increment_request_metric
            # treats None as "count nothing".
            await binding_service.increment_request_metric("intercepted" if binding_result.should_alert else None)
            if binding_result.should_alert and binding_result.token_hash and binding_result.token_display and binding_result.bound_ip:
                await alert_service.handle_intercept(
                    token_hash=binding_result.token_hash,
                    token_display=binding_result.token_display,
                    bound_ip=binding_result.bound_ip,
                    attempt_ip=client_ip,
                )
            # Rejected requests never reach the downstream service.
            return JSONResponse(
                status_code=binding_result.status_code,
                content={"detail": binding_result.detail},
            )
    else:
        # Requests without a bearer token bypass binding enforcement.
        await binding_service.increment_request_metric("allowed")

    downstream_client: httpx.AsyncClient = request.app.state.downstream_client
    upstream_url = build_upstream_url(settings, request)
    upstream_headers = build_upstream_headers(request, settings.downstream_url)

    try:
        # content=request.stream() forwards the body without buffering it.
        upstream_request = downstream_client.build_request(
            request.method,
            upstream_url,
            params=request.query_params.multi_items(),
            headers=upstream_headers,
            content=request.stream(),
        )
        upstream_response = await downstream_client.send(upstream_request, stream=True)
    except httpx.HTTPError as exc:
        logger.exception("Failed to reach downstream service.")
        return JSONResponse(status_code=502, content={"detail": f"Downstream request failed: {exc!s}"})

    async def stream_response():
        # Relay raw bytes and release the upstream connection even if the
        # client disconnects mid-stream.
        try:
            async for chunk in upstream_response.aiter_raw():
                yield chunk
        finally:
            await upstream_response.aclose()

    response = StreamingResponse(
        stream_response(),
        status_code=upstream_response.status_code,
        media_type=None,
    )
    apply_downstream_headers(response, upstream_response)
    return response
|
||||
13
app/schemas/auth.py
Normal file
13
app/schemas/auth.py
Normal file
@@ -0,0 +1,13 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class LoginRequest(BaseModel):
    """Request body for the admin login endpoint."""

    password: str = Field(min_length=1, max_length=256)
|
||||
|
||||
|
||||
class TokenResponse(BaseModel):
    """OAuth2-style bearer token response returned after a successful login."""

    access_token: str
    token_type: str = "bearer"
    # Token lifetime in seconds.
    expires_in: int
|
||||
42
app/schemas/binding.py
Normal file
42
app/schemas/binding.py
Normal file
@@ -0,0 +1,42 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
|
||||
|
||||
class BindingItem(BaseModel):
    """Single token-binding row as exposed to the admin API."""

    # from_attributes allows construction directly from ORM objects.
    model_config = ConfigDict(from_attributes=True)

    id: int
    token_display: str
    bound_ip: str
    status: int
    # Human-readable form of ``status``.
    status_label: str
    first_used_at: datetime
    last_used_at: datetime
    created_at: datetime
|
||||
|
||||
|
||||
class BindingListResponse(BaseModel):
    """Paginated listing of token bindings."""

    items: list[BindingItem]
    total: int
    page: int
    page_size: int
|
||||
|
||||
|
||||
class BindingActionRequest(BaseModel):
    """Request that targets a single binding by its primary key."""

    id: int = Field(gt=0)
|
||||
|
||||
|
||||
class BindingIPUpdateRequest(BaseModel):
    """Request to rebind a token to a new IP address or CIDR block."""

    id: int = Field(gt=0)
    bound_ip: str = Field(min_length=3, max_length=64)

    @field_validator("bound_ip")
    @classmethod
    def validate_bound_ip(cls, value: str) -> str:
        """Raise ValueError for anything that is not a valid IP network.

        strict=False accepts host addresses with set host bits (e.g. 10.0.0.5/24).
        """
        from ipaddress import ip_network

        ip_network(value, strict=False)
        return value
|
||||
23
app/schemas/log.py
Normal file
23
app/schemas/log.py
Normal file
@@ -0,0 +1,23 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
|
||||
class InterceptLogItem(BaseModel):
    """Single intercept-log row as exposed to the admin API."""

    # from_attributes allows construction directly from ORM objects.
    model_config = ConfigDict(from_attributes=True)

    id: int
    token_display: str
    bound_ip: str
    attempt_ip: str
    alerted: bool
    intercepted_at: datetime
|
||||
|
||||
|
||||
class LogListResponse(BaseModel):
    """Paginated listing of intercept logs."""

    items: list[InterceptLogItem]
    total: int
    page: int
    page_size: int
|
||||
123
app/services/alert_service.py
Normal file
123
app/services/alert_service.py
Normal file
@@ -0,0 +1,123 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Callable
|
||||
|
||||
import httpx
|
||||
from redis.asyncio import Redis
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
|
||||
from app.config import RuntimeSettings, Settings
|
||||
from app.models.intercept_log import InterceptLog
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AlertService:
    """Record intercept events and fire a webhook once a token accumulates
    enough intercepts within the configured time window.

    Every external interaction (DB, Redis, webhook) is best-effort: failures
    are logged and swallowed so proxying is never broken by alerting.
    """

    def __init__(
        self,
        settings: Settings,
        session_factory: async_sessionmaker[AsyncSession],
        redis: Redis | None,
        http_client: httpx.AsyncClient,
        runtime_settings_getter: Callable[[], RuntimeSettings],
    ) -> None:
        self.settings = settings
        self.session_factory = session_factory
        # May be None: alert counters (and hence webhooks) are disabled then.
        self.redis = redis
        self.http_client = http_client
        # Getter indirection so admin-updated runtime settings apply live.
        self.runtime_settings_getter = runtime_settings_getter

    def alert_key(self, token_hash: str) -> str:
        """Redis key holding the intercept counter for one token."""
        return f"sentinel:alert:{token_hash}"

    async def handle_intercept(
        self,
        token_hash: str,
        token_display: str,
        bound_ip: str,
        attempt_ip: str,
    ) -> None:
        """Persist the intercept, bump its Redis counter, and alert on threshold."""
        await self._write_intercept_log(token_hash, token_display, bound_ip, attempt_ip)

        runtime_settings = self.runtime_settings_getter()
        if self.redis is None:
            logger.warning("Redis is unavailable. Intercept alert counters are disabled.")
            return

        try:
            # INCR + EXPIRE in one transaction; EXPIRE refreshes the TTL on
            # every intercept, so the window slides with activity.
            async with self.redis.pipeline(transaction=True) as pipeline:
                pipeline.incr(self.alert_key(token_hash))
                pipeline.expire(self.alert_key(token_hash), runtime_settings.alert_threshold_seconds)
                result = await pipeline.execute()
        except Exception:
            logger.warning("Failed to update intercept alert counter.", extra={"token_hash": token_hash})
            return

        # result[0] is the value returned by INCR.
        count = int(result[0])
        if count < runtime_settings.alert_threshold_count:
            return

        payload = {
            "token_display": token_display,
            "attempt_ip": attempt_ip,
            "bound_ip": bound_ip,
            "count": count,
        }
        if runtime_settings.alert_webhook_url:
            try:
                await self.http_client.post(runtime_settings.alert_webhook_url, json=payload)
            except httpx.HTTPError:
                logger.exception("Failed to deliver alert webhook.", extra={"token_hash": token_hash})

        # Reset the counter so the next alert requires a fresh burst.
        try:
            await self.redis.delete(self.alert_key(token_hash))
        except Exception:
            logger.warning("Failed to clear intercept alert counter.", extra={"token_hash": token_hash})

        await self._mark_alerted_records(token_hash, runtime_settings.alert_threshold_seconds)

    async def _write_intercept_log(
        self,
        token_hash: str,
        token_display: str,
        bound_ip: str,
        attempt_ip: str,
    ) -> None:
        """Insert one InterceptLog row; roll back and log on DB errors."""
        async with self.session_factory() as session:
            try:
                session.add(
                    InterceptLog(
                        token_hash=token_hash,
                        token_display=token_display,
                        bound_ip=bound_ip,
                        attempt_ip=attempt_ip,
                        alerted=False,
                    )
                )
                await session.commit()
            except SQLAlchemyError:
                await session.rollback()
                logger.exception("Failed to write intercept log.", extra={"token_hash": token_hash})

    async def _mark_alerted_records(self, token_hash: str, threshold_seconds: int) -> None:
        """Flag the token's intercept rows from the alert window as alerted."""
        statement = text(
            """
            UPDATE intercept_logs
            SET alerted = TRUE
            WHERE token_hash = :token_hash
            AND intercepted_at >= NOW() - (:threshold_seconds || ' seconds')::interval
            """
        )
        async with self.session_factory() as session:
            try:
                await session.execute(
                    statement,
                    {"token_hash": token_hash, "threshold_seconds": threshold_seconds},
                )
                await session.commit()
            except SQLAlchemyError:
                await session.rollback()
                logger.exception("Failed to mark intercept logs as alerted.", extra={"token_hash": token_hash})
|
||||
84
app/services/archive_service.py
Normal file
84
app/services/archive_service.py
Normal file
@@ -0,0 +1,84 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import Callable
|
||||
|
||||
from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
||||
from sqlalchemy import delete, select
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
|
||||
from app.config import RuntimeSettings, Settings
|
||||
from app.models.token_binding import TokenBinding
|
||||
from app.services.binding_service import BindingService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ArchiveService:
    """Periodic job that prunes token bindings idle for longer than the
    configured number of days and evicts their cache entries.

    NOTE(review): "archive" here means delete — rows are removed, not copied
    to an archive table. Confirm that is the intended retention behavior.
    """

    def __init__(
        self,
        settings: Settings,
        session_factory: async_sessionmaker[AsyncSession],
        binding_service: BindingService,
        runtime_settings_getter: Callable[[], RuntimeSettings],
    ) -> None:
        self.settings = settings
        self.session_factory = session_factory
        self.binding_service = binding_service
        # Getter indirection so admin-updated runtime settings apply live.
        self.runtime_settings_getter = runtime_settings_getter
        self.scheduler = AsyncIOScheduler(timezone="UTC")

    async def start(self) -> None:
        """Register the interval job and start the scheduler (idempotent)."""
        if self.scheduler.running:
            return
        self.scheduler.add_job(
            self.archive_inactive_bindings,
            trigger="interval",
            minutes=self.settings.archive_job_interval_minutes,
            id="archive-inactive-bindings",
            replace_existing=True,
            # Never run two archive passes concurrently; collapse missed runs.
            max_instances=1,
            coalesce=True,
        )
        self.scheduler.start()

    async def stop(self) -> None:
        """Stop the scheduler without waiting for a running job to finish."""
        if self.scheduler.running:
            self.scheduler.shutdown(wait=False)

    async def archive_inactive_bindings(self) -> int:
        """Delete bindings whose last_used_at is older than the cutoff.

        Operates in batches of ``archive_batch_size`` to bound transaction
        size. Returns the total number of bindings removed.
        """
        runtime_settings = self.runtime_settings_getter()
        cutoff = datetime.now(UTC) - timedelta(days=runtime_settings.archive_days)
        total_archived = 0

        while True:
            async with self.session_factory() as session:
                try:
                    result = await session.execute(
                        select(TokenBinding.token_hash)
                        .where(TokenBinding.last_used_at < cutoff)
                        .order_by(TokenBinding.last_used_at.asc())
                        .limit(self.settings.archive_batch_size)
                    )
                    token_hashes = list(result.scalars())
                    if not token_hashes:
                        break

                    await session.execute(delete(TokenBinding).where(TokenBinding.token_hash.in_(token_hashes)))
                    await session.commit()
                except SQLAlchemyError:
                    await session.rollback()
                    logger.exception("Failed to archive inactive bindings.")
                    break

            # Evict cached bindings so deleted tokens are re-evaluated on next use.
            await self.binding_service.invalidate_many(token_hashes)
            total_archived += len(token_hashes)

            # A short batch means no more candidates remain.
            if len(token_hashes) < self.settings.archive_batch_size:
                break

        if total_archived:
            logger.info("Archived inactive bindings.", extra={"count": total_archived})
        return total_archived
|
||||
464
app/services/binding_service.py
Normal file
464
app/services/binding_service.py
Normal file
@@ -0,0 +1,464 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, date, timedelta
|
||||
from typing import Callable
|
||||
|
||||
from redis.asyncio import Redis
|
||||
from sqlalchemy import func, select, text, update
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
|
||||
from app.config import RuntimeSettings, Settings
|
||||
from app.core.ip_utils import is_ip_in_network
|
||||
from app.core.security import hash_token, mask_token
|
||||
from app.models.token_binding import STATUS_ACTIVE, STATUS_BANNED, TokenBinding
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass(slots=True)
class BindingRecord:
    """Snapshot of one token-binding row plus the IP-match verdict computed
    at lookup time."""

    id: int
    token_hash: str
    token_display: str
    bound_ip: str
    status: int
    # True when the client IP used for the lookup fell inside ``bound_ip``.
    ip_matched: bool
|
||||
|
||||
|
||||
@dataclass(slots=True)
class BindingCheckResult:
    """Outcome of evaluating a token/IP pair against its binding."""

    # Whether the proxy should forward the request downstream.
    allowed: bool
    # HTTP status to return to the client when the request is rejected.
    status_code: int
    detail: str
    token_hash: str | None = None
    token_display: str | None = None
    bound_ip: str | None = None
    # True when this rejection should feed the alert pipeline.
    should_alert: bool = False
    # True when this request created the first-use binding.
    newly_bound: bool = False
|
||||
|
||||
|
||||
class InMemoryRateLimiter:
|
||||
def __init__(self, max_per_second: int) -> None:
|
||||
self.max_per_second = max(1, max_per_second)
|
||||
self._window = int(time.monotonic())
|
||||
self._count = 0
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
async def allow(self) -> bool:
|
||||
async with self._lock:
|
||||
current_window = int(time.monotonic())
|
||||
if current_window != self._window:
|
||||
self._window = current_window
|
||||
self._count = 0
|
||||
if self._count >= self.max_per_second:
|
||||
return False
|
||||
self._count += 1
|
||||
return True
|
||||
|
||||
|
||||
class BindingService:
|
||||
def __init__(
|
||||
self,
|
||||
settings: Settings,
|
||||
session_factory: async_sessionmaker[AsyncSession],
|
||||
redis: Redis | None,
|
||||
runtime_settings_getter: Callable[[], RuntimeSettings],
|
||||
) -> None:
|
||||
self.settings = settings
|
||||
self.session_factory = session_factory
|
||||
self.redis = redis
|
||||
self.runtime_settings_getter = runtime_settings_getter
|
||||
self.last_used_queue: asyncio.Queue[str] = asyncio.Queue(maxsize=settings.last_used_queue_size)
|
||||
self._flush_task: asyncio.Task[None] | None = None
|
||||
self._stop_event = asyncio.Event()
|
||||
self._redis_degraded_limiter = InMemoryRateLimiter(settings.downstream_max_connections // 2)
|
||||
|
||||
async def start(self) -> None:
|
||||
if self._flush_task is None:
|
||||
self._stop_event.clear()
|
||||
self._flush_task = asyncio.create_task(self._flush_loop(), name="binding-last-used-flush")
|
||||
|
||||
async def stop(self) -> None:
|
||||
self._stop_event.set()
|
||||
if self._flush_task is not None:
|
||||
self._flush_task.cancel()
|
||||
try:
|
||||
await self._flush_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
self._flush_task = None
|
||||
await self.flush_last_used_updates()
|
||||
|
||||
def status_label(self, status_code: int) -> str:
|
||||
return "Active" if status_code == STATUS_ACTIVE else "Banned"
|
||||
|
||||
def cache_key(self, token_hash: str) -> str:
|
||||
return f"sentinel:binding:{token_hash}"
|
||||
|
||||
def metrics_key(self, target_date: date) -> str:
|
||||
return f"sentinel:metrics:{target_date.isoformat()}"
|
||||
|
||||
async def evaluate_token_binding(self, token: str, client_ip: str) -> BindingCheckResult:
|
||||
token_hash = hash_token(token, self.settings.sentinel_hmac_secret)
|
||||
token_display = mask_token(token)
|
||||
|
||||
cache_hit, cache_available = await self._load_binding_from_cache(token_hash)
|
||||
if cache_hit is not None:
|
||||
if cache_hit.status == STATUS_BANNED:
|
||||
return BindingCheckResult(
|
||||
allowed=False,
|
||||
status_code=403,
|
||||
detail="Token is banned.",
|
||||
token_hash=token_hash,
|
||||
token_display=token_display,
|
||||
bound_ip=cache_hit.bound_ip,
|
||||
should_alert=True,
|
||||
)
|
||||
if is_ip_in_network(client_ip, cache_hit.bound_ip):
|
||||
await self._touch_cache(token_hash)
|
||||
self.record_last_used(token_hash)
|
||||
return BindingCheckResult(
|
||||
allowed=True,
|
||||
status_code=200,
|
||||
detail="Allowed from cache.",
|
||||
token_hash=token_hash,
|
||||
token_display=token_display,
|
||||
bound_ip=cache_hit.bound_ip,
|
||||
)
|
||||
return BindingCheckResult(
|
||||
allowed=False,
|
||||
status_code=403,
|
||||
detail="Client IP does not match the bound CIDR.",
|
||||
token_hash=token_hash,
|
||||
token_display=token_display,
|
||||
bound_ip=cache_hit.bound_ip,
|
||||
should_alert=True,
|
||||
)
|
||||
|
||||
if not cache_available:
|
||||
logger.warning("Redis is unavailable. Falling back to PostgreSQL for token binding.")
|
||||
if not await self._redis_degraded_limiter.allow():
|
||||
logger.warning("Redis degraded limiter rejected a request during PostgreSQL fallback.")
|
||||
return BindingCheckResult(
|
||||
allowed=False,
|
||||
status_code=429,
|
||||
detail="Redis degraded mode rate limit reached.",
|
||||
token_hash=token_hash,
|
||||
token_display=token_display,
|
||||
)
|
||||
|
||||
try:
|
||||
record = await self._load_binding_from_db(token_hash, client_ip)
|
||||
except SQLAlchemyError:
|
||||
return self._handle_backend_failure(token_hash, token_display)
|
||||
|
||||
if record is not None:
|
||||
await self.sync_binding_cache(record.token_hash, record.bound_ip, record.status)
|
||||
if record.status == STATUS_BANNED:
|
||||
return BindingCheckResult(
|
||||
allowed=False,
|
||||
status_code=403,
|
||||
detail="Token is banned.",
|
||||
token_hash=token_hash,
|
||||
token_display=token_display,
|
||||
bound_ip=record.bound_ip,
|
||||
should_alert=True,
|
||||
)
|
||||
if record.ip_matched:
|
||||
self.record_last_used(token_hash)
|
||||
return BindingCheckResult(
|
||||
allowed=True,
|
||||
status_code=200,
|
||||
detail="Allowed from PostgreSQL.",
|
||||
token_hash=token_hash,
|
||||
token_display=token_display,
|
||||
bound_ip=record.bound_ip,
|
||||
)
|
||||
return BindingCheckResult(
|
||||
allowed=False,
|
||||
status_code=403,
|
||||
detail="Client IP does not match the bound CIDR.",
|
||||
token_hash=token_hash,
|
||||
token_display=token_display,
|
||||
bound_ip=record.bound_ip,
|
||||
should_alert=True,
|
||||
)
|
||||
|
||||
try:
|
||||
created = await self._create_binding(token_hash, token_display, client_ip)
|
||||
except SQLAlchemyError:
|
||||
return self._handle_backend_failure(token_hash, token_display)
|
||||
|
||||
if created is None:
|
||||
try:
|
||||
existing = await self._load_binding_from_db(token_hash, client_ip)
|
||||
except SQLAlchemyError:
|
||||
return self._handle_backend_failure(token_hash, token_display)
|
||||
if existing is None:
|
||||
return self._handle_backend_failure(token_hash, token_display)
|
||||
await self.sync_binding_cache(existing.token_hash, existing.bound_ip, existing.status)
|
||||
if existing.status == STATUS_BANNED:
|
||||
return BindingCheckResult(
|
||||
allowed=False,
|
||||
status_code=403,
|
||||
detail="Token is banned.",
|
||||
token_hash=token_hash,
|
||||
token_display=token_display,
|
||||
bound_ip=existing.bound_ip,
|
||||
should_alert=True,
|
||||
)
|
||||
if existing.ip_matched:
|
||||
self.record_last_used(token_hash)
|
||||
return BindingCheckResult(
|
||||
allowed=True,
|
||||
status_code=200,
|
||||
detail="Allowed after concurrent bind resolution.",
|
||||
token_hash=token_hash,
|
||||
token_display=token_display,
|
||||
bound_ip=existing.bound_ip,
|
||||
)
|
||||
return BindingCheckResult(
|
||||
allowed=False,
|
||||
status_code=403,
|
||||
detail="Client IP does not match the bound CIDR.",
|
||||
token_hash=token_hash,
|
||||
token_display=token_display,
|
||||
bound_ip=existing.bound_ip,
|
||||
should_alert=True,
|
||||
)
|
||||
|
||||
await self.sync_binding_cache(created.token_hash, created.bound_ip, created.status)
|
||||
return BindingCheckResult(
|
||||
allowed=True,
|
||||
status_code=200,
|
||||
detail="First-use bind created.",
|
||||
token_hash=token_hash,
|
||||
token_display=token_display,
|
||||
bound_ip=created.bound_ip,
|
||||
newly_bound=True,
|
||||
)
|
||||
|
||||
async def sync_binding_cache(self, token_hash: str, bound_ip: str, status_code: int) -> None:
    """Write the (bound_ip, status) pair for *token_hash* into the Redis cache.

    Best-effort: a missing Redis connection makes this a no-op, and write
    failures are logged rather than raised so cache trouble never breaks
    the caller.
    """
    redis_client = self.redis
    if redis_client is None:
        return
    cache_entry = json.dumps({"bound_ip": bound_ip, "status": status_code})
    ttl_seconds = self.settings.redis_binding_ttl_seconds
    try:
        await redis_client.set(self.cache_key(token_hash), cache_entry, ex=ttl_seconds)
    except Exception:
        logger.warning("Failed to write binding cache.", extra={"token_hash": token_hash})
|
||||
|
||||
async def invalidate_binding_cache(self, token_hash: str) -> None:
    """Drop the cached binding entry for *token_hash*.

    Silently does nothing when Redis is offline; delete failures are
    logged and swallowed so the caller is never affected.
    """
    client = self.redis
    if client is not None:
        try:
            await client.delete(self.cache_key(token_hash))
        except Exception:
            logger.warning("Failed to delete binding cache.", extra={"token_hash": token_hash})
|
||||
|
||||
async def invalidate_many(self, token_hashes: list[str]) -> None:
    """Remove the cached binding entries for every hash in *token_hashes*.

    No-op for an empty list or when Redis is offline; a failed bulk delete
    is logged and ignored.
    """
    if not token_hashes or self.redis is None:
        return
    cache_keys = [self.cache_key(token_hash) for token_hash in token_hashes]
    try:
        await self.redis.delete(*cache_keys)
    except Exception:
        logger.warning("Failed to delete multiple binding cache keys.", extra={"count": len(cache_keys)})
|
||||
|
||||
def record_last_used(self, token_hash: str) -> None:
    """Queue *token_hash* for a deferred last_used_at database update.

    Non-blocking by design: when the queue is already at capacity the
    update is dropped with a warning instead of stalling the request path.
    """
    queue = self.last_used_queue
    try:
        queue.put_nowait(token_hash)
    except asyncio.QueueFull:
        logger.warning("last_used queue is full; dropping update.", extra={"token_hash": token_hash})
|
||||
|
||||
async def flush_last_used_updates(self) -> None:
    """Drain the pending queue and bulk-update last_used_at for each token.

    Collects every queued hash (deduplicated via a set), then issues one
    UPDATE covering all of them.  A database failure is rolled back and
    logged; it never propagates to the background loop.
    """
    pending: set[str] = set()
    while not self.last_used_queue.empty():
        try:
            pending.add(self.last_used_queue.get_nowait())
        except asyncio.QueueEmpty:
            break

    if not pending:
        return

    async with self.session_factory() as session:
        try:
            await session.execute(
                update(TokenBinding)
                .where(TokenBinding.token_hash.in_(pending))
                .values(last_used_at=func.now())
            )
            await session.commit()
        except SQLAlchemyError:
            await session.rollback()
            logger.exception("Failed to flush last_used_at updates.", extra={"count": len(pending)})
|
||||
|
||||
async def increment_request_metric(self, outcome: str | None) -> None:
|
||||
if self.redis is None:
|
||||
return
|
||||
key = self.metrics_key(date.today())
|
||||
ttl = self.settings.metrics_ttl_days * 86400
|
||||
try:
|
||||
async with self.redis.pipeline(transaction=True) as pipeline:
|
||||
pipeline.hincrby(key, "total", 1)
|
||||
if outcome in {"allowed", "intercepted"}:
|
||||
pipeline.hincrby(key, outcome, 1)
|
||||
pipeline.expire(key, ttl)
|
||||
await pipeline.execute()
|
||||
except Exception:
|
||||
logger.warning("Failed to increment request metrics.", extra={"outcome": outcome})
|
||||
|
||||
async def get_metrics_window(self, days: int = 7) -> list[dict[str, int | str]]:
|
||||
if self.redis is None:
|
||||
return [
|
||||
{"date": (date.today() - timedelta(days=offset)).isoformat(), "allowed": 0, "intercepted": 0, "total": 0}
|
||||
for offset in range(days - 1, -1, -1)
|
||||
]
|
||||
|
||||
series: list[dict[str, int | str]] = []
|
||||
for offset in range(days - 1, -1, -1):
|
||||
target = date.today() - timedelta(days=offset)
|
||||
raw = await self.redis.hgetall(self.metrics_key(target))
|
||||
series.append(
|
||||
{
|
||||
"date": target.isoformat(),
|
||||
"allowed": int(raw.get("allowed", 0)),
|
||||
"intercepted": int(raw.get("intercepted", 0)),
|
||||
"total": int(raw.get("total", 0)),
|
||||
}
|
||||
)
|
||||
return series
|
||||
|
||||
async def _load_binding_from_cache(self, token_hash: str) -> tuple[BindingRecord | None, bool]:
|
||||
if self.redis is None:
|
||||
return None, False
|
||||
try:
|
||||
raw = await self.redis.get(self.cache_key(token_hash))
|
||||
except Exception:
|
||||
logger.warning("Failed to read binding cache.", extra={"token_hash": token_hash})
|
||||
return None, False
|
||||
if raw is None:
|
||||
return None, True
|
||||
|
||||
data = json.loads(raw)
|
||||
return (
|
||||
BindingRecord(
|
||||
id=0,
|
||||
token_hash=token_hash,
|
||||
token_display="",
|
||||
bound_ip=data["bound_ip"],
|
||||
status=int(data["status"]),
|
||||
ip_matched=False,
|
||||
),
|
||||
True,
|
||||
)
|
||||
|
||||
async def _touch_cache(self, token_hash: str) -> None:
|
||||
if self.redis is None:
|
||||
return
|
||||
try:
|
||||
await self.redis.expire(self.cache_key(token_hash), self.settings.redis_binding_ttl_seconds)
|
||||
except Exception:
|
||||
logger.warning("Failed to extend binding cache TTL.", extra={"token_hash": token_hash})
|
||||
|
||||
async def _load_binding_from_db(self, token_hash: str, client_ip: str) -> BindingRecord | None:
    """Fetch the binding row for *token_hash* and evaluate *client_ip* in SQL.

    Returns None when no binding exists.  ``ip_matched`` is True when the
    client address falls inside (or equals) the bound CIDR.
    """
    # BUGFIX: use `<<=` (contained-in-or-equals), not strict `<<`.
    # First-use binding stores the client host address cast to cidr
    # (a /32 or /128), and `inet x << cidr x/32` is FALSE in PostgreSQL,
    # so the strict operator would reject the exact bound IP.
    query = text(
        """
        SELECT
            id,
            token_hash,
            token_display,
            bound_ip::text AS bound_ip,
            status,
            CAST(:client_ip AS inet) <<= bound_ip AS ip_matched
        FROM token_bindings
        WHERE token_hash = :token_hash
        LIMIT 1
        """
    )
    async with self.session_factory() as session:
        result = await session.execute(query, {"token_hash": token_hash, "client_ip": client_ip})
        row = result.mappings().first()
    if row is None:
        return None
    return BindingRecord(
        id=int(row["id"]),
        token_hash=str(row["token_hash"]),
        token_display=str(row["token_display"]),
        bound_ip=str(row["bound_ip"]),
        status=int(row["status"]),
        ip_matched=bool(row["ip_matched"]),
    )
|
||||
|
||||
async def _create_binding(self, token_hash: str, token_display: str, client_ip: str) -> BindingRecord | None:
    """Insert a first-use binding row; return None if the token already exists.

    Relies on ``ON CONFLICT DO NOTHING`` so concurrent binders race safely:
    exactly one insert wins and the losers receive None.  Database errors
    roll back the transaction and re-raise for the caller to handle.
    """
    statement = text(
        """
        INSERT INTO token_bindings (token_hash, token_display, bound_ip, status)
        VALUES (:token_hash, :token_display, CAST(:bound_ip AS cidr), :status)
        ON CONFLICT (token_hash) DO NOTHING
        RETURNING id, token_hash, token_display, bound_ip::text AS bound_ip, status
        """
    )
    params = {
        "token_hash": token_hash,
        "token_display": token_display,
        "bound_ip": client_ip,
        "status": STATUS_ACTIVE,
    }
    async with self.session_factory() as session:
        try:
            result = await session.execute(statement, params)
            row = result.mappings().first()
            await session.commit()
        except SQLAlchemyError:
            await session.rollback()
            raise
    if row is None:
        return None
    return BindingRecord(
        id=int(row["id"]),
        token_hash=str(row["token_hash"]),
        token_display=str(row["token_display"]),
        bound_ip=str(row["bound_ip"]),
        status=int(row["status"]),
        ip_matched=True,
    )
|
||||
|
||||
def _handle_backend_failure(self, token_hash: str, token_display: str) -> BindingCheckResult:
    """Translate a storage failure into a verdict according to failsafe mode.

    Expected to be called from an active ``except`` block, so
    ``logger.exception`` attaches the current traceback.  "open" mode lets
    the request through (200); any other mode fails closed with 503.
    """
    runtime_settings = self.runtime_settings_getter()
    logger.exception(
        "Binding storage backend failed.",
        extra={"failsafe_mode": runtime_settings.failsafe_mode, "token_hash": token_hash},
    )
    common = {"token_hash": token_hash, "token_display": token_display}
    if runtime_settings.failsafe_mode == "open":
        return BindingCheckResult(
            allowed=True,
            status_code=200,
            detail="Allowed by failsafe mode.",
            **common,
        )
    return BindingCheckResult(
        allowed=False,
        status_code=503,
        detail="Binding backend unavailable and failsafe mode is closed.",
        **common,
    )
|
||||
|
||||
async def _flush_loop(self) -> None:
|
||||
try:
|
||||
while not self._stop_event.is_set():
|
||||
await asyncio.sleep(self.settings.last_used_flush_interval_seconds)
|
||||
await self.flush_last_used_updates()
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
Reference in New Issue
Block a user