from __future__ import annotations
|
||
|
|
|
||
|
|
import asyncio
|
||
|
|
import json
|
||
|
|
import logging
|
||
|
|
import time
|
||
|
|
from dataclasses import dataclass
|
||
|
|
from datetime import UTC, date, timedelta
|
||
|
|
from typing import Callable
|
||
|
|
|
||
|
|
from redis.asyncio import Redis
|
||
|
|
from sqlalchemy import func, select, text, update
|
||
|
|
from sqlalchemy.exc import SQLAlchemyError
|
||
|
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||
|
|
|
||
|
|
from app.config import RuntimeSettings, Settings
|
||
|
|
from app.core.ip_utils import is_ip_in_network
|
||
|
|
from app.core.security import hash_token, mask_token
|
||
|
|
from app.models.token_binding import STATUS_ACTIVE, STATUS_BANNED, TokenBinding
|
||
|
|
|
||
|
|
# Module-level logger for the token-binding service.
logger = logging.getLogger(__name__)
@dataclass(slots=True)
class BindingRecord:
    """A token binding as loaded from PostgreSQL or hydrated from the Redis cache."""

    id: int  # Primary key of the row; 0 when hydrated from cache (no id stored there).
    token_hash: str  # HMAC hash of the raw token (see hash_token).
    token_display: str  # Masked token for display; empty string when hydrated from cache.
    bound_ip: str  # Bound network in CIDR text form.
    status: int  # STATUS_ACTIVE or STATUS_BANNED.
    ip_matched: bool  # Whether the querying client IP fell inside bound_ip; only meaningful for DB lookups (cache hydration sets False).
@dataclass(slots=True)
class BindingCheckResult:
    """Outcome of evaluating a token / client-IP pair against its binding."""

    allowed: bool  # True when the request may proceed downstream.
    status_code: int  # HTTP-style status for the decision (200, 403, 429, 503).
    detail: str  # Human-readable explanation of the decision.
    token_hash: str | None = None  # HMAC hash of the evaluated token, when computed.
    token_display: str | None = None  # Masked token for logs/alerts.
    bound_ip: str | None = None  # CIDR the token is bound to, when known.
    should_alert: bool = False  # Set on ban hits and CIDR mismatches (alert-worthy denials).
    newly_bound: bool = False  # Set when this request created the first-use binding.
class InMemoryRateLimiter:
|
||
|
|
def __init__(self, max_per_second: int) -> None:
|
||
|
|
self.max_per_second = max(1, max_per_second)
|
||
|
|
self._window = int(time.monotonic())
|
||
|
|
self._count = 0
|
||
|
|
self._lock = asyncio.Lock()
|
||
|
|
|
||
|
|
async def allow(self) -> bool:
|
||
|
|
async with self._lock:
|
||
|
|
current_window = int(time.monotonic())
|
||
|
|
if current_window != self._window:
|
||
|
|
self._window = current_window
|
||
|
|
self._count = 0
|
||
|
|
if self._count >= self.max_per_second:
|
||
|
|
return False
|
||
|
|
self._count += 1
|
||
|
|
return True
|
||
|
|
|
||
|
|
|
||
|
|
class BindingService:
    """Token-to-IP binding enforcement backed by Redis (cache) and PostgreSQL.

    Lookup order: the Redis cache is consulted first; on a cache miss — or when
    Redis is unavailable — the service falls back to PostgreSQL; tokens never
    seen before are bound to the caller's IP on first use. ``last_used_at``
    writes are batched through a bounded in-memory queue that a background
    task flushes periodically.
    """

    def __init__(
        self,
        settings: Settings,
        session_factory: async_sessionmaker[AsyncSession],
        redis: Redis | None,
        runtime_settings_getter: Callable[[], RuntimeSettings],
    ) -> None:
        """Wire up dependencies, the last-used queue, and the degraded-mode limiter.

        Args:
            settings: Static application settings (HMAC secret, TTLs, queue sizes).
            session_factory: Factory producing async SQLAlchemy sessions.
            redis: Async Redis client, or ``None`` when caching is disabled.
            runtime_settings_getter: Callable returning current mutable runtime
                settings (e.g. ``failsafe_mode``).
        """
        self.settings = settings
        self.session_factory = session_factory
        self.redis = redis
        self.runtime_settings_getter = runtime_settings_getter
        # Bounded queue of token hashes awaiting a batched last_used_at update;
        # overflow is dropped (see record_last_used) rather than blocking requests.
        self.last_used_queue: asyncio.Queue[str] = asyncio.Queue(maxsize=settings.last_used_queue_size)
        self._flush_task: asyncio.Task[None] | None = None
        self._stop_event = asyncio.Event()
        # Caps direct PostgreSQL traffic while Redis is down: half the downstream
        # connection budget (InMemoryRateLimiter clamps the value to at least 1/s).
        self._redis_degraded_limiter = InMemoryRateLimiter(settings.downstream_max_connections // 2)

    async def start(self) -> None:
        """Start the background last_used_at flush task (idempotent)."""
        if self._flush_task is None:
            self._stop_event.clear()
            self._flush_task = asyncio.create_task(self._flush_loop(), name="binding-last-used-flush")

    async def stop(self) -> None:
        """Stop the flush task, then drain any queued last_used_at updates."""
        self._stop_event.set()
        if self._flush_task is not None:
            self._flush_task.cancel()
            try:
                await self._flush_task
            except asyncio.CancelledError:
                pass
            self._flush_task = None
        # Final drain so updates queued after the last periodic flush are not lost.
        await self.flush_last_used_updates()

    def status_label(self, status_code: int) -> str:
        """Return a human-readable label for a binding status code."""
        return "Active" if status_code == STATUS_ACTIVE else "Banned"

    def cache_key(self, token_hash: str) -> str:
        """Redis key holding the cached binding for ``token_hash``."""
        return f"sentinel:binding:{token_hash}"

    def metrics_key(self, target_date: date) -> str:
        """Redis hash key holding request metrics for ``target_date``."""
        return f"sentinel:metrics:{target_date.isoformat()}"

    async def evaluate_token_binding(self, token: str, client_ip: str) -> BindingCheckResult:
        """Decide whether ``token`` may be used from ``client_ip``.

        Checks the Redis cache, then PostgreSQL, then attempts a first-use
        bind. Backend failures are routed through the failsafe policy
        (``_handle_backend_failure``).

        Returns:
            A BindingCheckResult; denials for bans and CIDR mismatches carry
            ``should_alert=True``, and a successful first-use bind carries
            ``newly_bound=True``.
        """
        token_hash = hash_token(token, self.settings.sentinel_hmac_secret)
        token_display = mask_token(token)

        cache_hit, cache_available = await self._load_binding_from_cache(token_hash)
        if cache_hit is not None:
            if cache_hit.status == STATUS_BANNED:
                return BindingCheckResult(
                    allowed=False,
                    status_code=403,
                    detail="Token is banned.",
                    token_hash=token_hash,
                    token_display=token_display,
                    bound_ip=cache_hit.bound_ip,
                    should_alert=True,
                )
            if is_ip_in_network(client_ip, cache_hit.bound_ip):
                # Keep hot entries alive and record usage without touching PostgreSQL.
                await self._touch_cache(token_hash)
                self.record_last_used(token_hash)
                return BindingCheckResult(
                    allowed=True,
                    status_code=200,
                    detail="Allowed from cache.",
                    token_hash=token_hash,
                    token_display=token_display,
                    bound_ip=cache_hit.bound_ip,
                )
            return BindingCheckResult(
                allowed=False,
                status_code=403,
                detail="Client IP does not match the bound CIDR.",
                token_hash=token_hash,
                token_display=token_display,
                bound_ip=cache_hit.bound_ip,
                should_alert=True,
            )

        if not cache_available:
            # Redis is down: protect PostgreSQL with the degraded-mode limiter.
            logger.warning("Redis is unavailable. Falling back to PostgreSQL for token binding.")
            if not await self._redis_degraded_limiter.allow():
                logger.warning("Redis degraded limiter rejected a request during PostgreSQL fallback.")
                return BindingCheckResult(
                    allowed=False,
                    status_code=429,
                    detail="Redis degraded mode rate limit reached.",
                    token_hash=token_hash,
                    token_display=token_display,
                )

        try:
            record = await self._load_binding_from_db(token_hash, client_ip)
        except SQLAlchemyError:
            return self._handle_backend_failure(token_hash, token_display)

        if record is not None:
            # Repopulate the cache even for banned tokens so future denials are cheap.
            await self.sync_binding_cache(record.token_hash, record.bound_ip, record.status)
            if record.status == STATUS_BANNED:
                return BindingCheckResult(
                    allowed=False,
                    status_code=403,
                    detail="Token is banned.",
                    token_hash=token_hash,
                    token_display=token_display,
                    bound_ip=record.bound_ip,
                    should_alert=True,
                )
            if record.ip_matched:
                self.record_last_used(token_hash)
                return BindingCheckResult(
                    allowed=True,
                    status_code=200,
                    detail="Allowed from PostgreSQL.",
                    token_hash=token_hash,
                    token_display=token_display,
                    bound_ip=record.bound_ip,
                )
            return BindingCheckResult(
                allowed=False,
                status_code=403,
                detail="Client IP does not match the bound CIDR.",
                token_hash=token_hash,
                token_display=token_display,
                bound_ip=record.bound_ip,
                should_alert=True,
            )

        # Unknown token: attempt the first-use bind.
        try:
            created = await self._create_binding(token_hash, token_display, client_ip)
        except SQLAlchemyError:
            return self._handle_backend_failure(token_hash, token_display)

        if created is None:
            # Lost the insert race (ON CONFLICT DO NOTHING): re-read the row the
            # concurrent request created and evaluate against it.
            try:
                existing = await self._load_binding_from_db(token_hash, client_ip)
            except SQLAlchemyError:
                return self._handle_backend_failure(token_hash, token_display)
            if existing is None:
                # Insert reported a conflict yet the row is gone — treat as backend failure.
                return self._handle_backend_failure(token_hash, token_display)
            await self.sync_binding_cache(existing.token_hash, existing.bound_ip, existing.status)
            if existing.status == STATUS_BANNED:
                return BindingCheckResult(
                    allowed=False,
                    status_code=403,
                    detail="Token is banned.",
                    token_hash=token_hash,
                    token_display=token_display,
                    bound_ip=existing.bound_ip,
                    should_alert=True,
                )
            if existing.ip_matched:
                self.record_last_used(token_hash)
                return BindingCheckResult(
                    allowed=True,
                    status_code=200,
                    detail="Allowed after concurrent bind resolution.",
                    token_hash=token_hash,
                    token_display=token_display,
                    bound_ip=existing.bound_ip,
                )
            return BindingCheckResult(
                allowed=False,
                status_code=403,
                detail="Client IP does not match the bound CIDR.",
                token_hash=token_hash,
                token_display=token_display,
                bound_ip=existing.bound_ip,
                should_alert=True,
            )

        await self.sync_binding_cache(created.token_hash, created.bound_ip, created.status)
        return BindingCheckResult(
            allowed=True,
            status_code=200,
            detail="First-use bind created.",
            token_hash=token_hash,
            token_display=token_display,
            bound_ip=created.bound_ip,
            newly_bound=True,
        )

    async def sync_binding_cache(self, token_hash: str, bound_ip: str, status_code: int) -> None:
        """Write (or refresh) the cached binding in Redis; best-effort."""
        if self.redis is None:
            return
        payload = json.dumps({"bound_ip": bound_ip, "status": status_code})
        try:
            await self.redis.set(self.cache_key(token_hash), payload, ex=self.settings.redis_binding_ttl_seconds)
        except Exception:
            # Cache write failures are non-fatal: PostgreSQL remains authoritative.
            logger.warning("Failed to write binding cache.", extra={"token_hash": token_hash})

    async def invalidate_binding_cache(self, token_hash: str) -> None:
        """Remove one cached binding from Redis; best-effort."""
        if self.redis is None:
            return
        try:
            await self.redis.delete(self.cache_key(token_hash))
        except Exception:
            logger.warning("Failed to delete binding cache.", extra={"token_hash": token_hash})

    async def invalidate_many(self, token_hashes: list[str]) -> None:
        """Remove several cached bindings in a single DEL; best-effort."""
        if self.redis is None or not token_hashes:
            return
        keys = [self.cache_key(item) for item in token_hashes]
        try:
            await self.redis.delete(*keys)
        except Exception:
            logger.warning("Failed to delete multiple binding cache keys.", extra={"count": len(keys)})

    def record_last_used(self, token_hash: str) -> None:
        """Queue a last_used_at update; drops the update if the queue is full."""
        try:
            self.last_used_queue.put_nowait(token_hash)
        except asyncio.QueueFull:
            # Deliberate shed: losing a last_used timestamp beats blocking a request.
            logger.warning("last_used queue is full; dropping update.", extra={"token_hash": token_hash})

    async def flush_last_used_updates(self) -> None:
        """Drain the queue and bulk-update last_used_at for the drained hashes."""
        # Deduplicate: a token used many times between flushes needs one UPDATE.
        token_hashes: set[str] = set()
        while True:
            try:
                token_hashes.add(self.last_used_queue.get_nowait())
            except asyncio.QueueEmpty:
                break

        if not token_hashes:
            return

        async with self.session_factory() as session:
            try:
                stmt = (
                    update(TokenBinding)
                    .where(TokenBinding.token_hash.in_(token_hashes))
                    .values(last_used_at=func.now())
                )
                await session.execute(stmt)
                await session.commit()
            except SQLAlchemyError:
                await session.rollback()
                # The drained hashes are lost on failure; acceptable for usage metadata.
                logger.exception("Failed to flush last_used_at updates.", extra={"count": len(token_hashes)})

    async def increment_request_metric(self, outcome: str | None) -> None:
        """Bump today's request counters in Redis; best-effort."""
        if self.redis is None:
            return
        key = self.metrics_key(date.today())
        ttl = self.settings.metrics_ttl_days * 86400
        try:
            async with self.redis.pipeline(transaction=True) as pipeline:
                pipeline.hincrby(key, "total", 1)
                # Only known outcomes get their own counter; anything else counts toward total only.
                if outcome in {"allowed", "intercepted"}:
                    pipeline.hincrby(key, outcome, 1)
                pipeline.expire(key, ttl)
                await pipeline.execute()
        except Exception:
            logger.warning("Failed to increment request metrics.", extra={"outcome": outcome})

    async def get_metrics_window(self, days: int = 7) -> list[dict[str, int | str]]:
        """Return per-day metrics for the trailing ``days`` days, oldest first.

        Days with no data (or no Redis) report zeros.
        """
        if self.redis is None:
            return [
                {"date": (date.today() - timedelta(days=offset)).isoformat(), "allowed": 0, "intercepted": 0, "total": 0}
                for offset in range(days - 1, -1, -1)
            ]

        series: list[dict[str, int | str]] = []
        for offset in range(days - 1, -1, -1):
            target = date.today() - timedelta(days=offset)
            # NOTE(review): assumes the Redis client decodes responses to str
            # (decode_responses=True); with bytes keys these .get() lookups would
            # always fall back to 0 — confirm client configuration.
            raw = await self.redis.hgetall(self.metrics_key(target))
            series.append(
                {
                    "date": target.isoformat(),
                    "allowed": int(raw.get("allowed", 0)),
                    "intercepted": int(raw.get("intercepted", 0)),
                    "total": int(raw.get("total", 0)),
                }
            )
        return series

    async def _load_binding_from_cache(self, token_hash: str) -> tuple[BindingRecord | None, bool]:
        """Read a binding from Redis.

        Returns:
            (record, cache_available): record is None on a miss; the second
            element is False when Redis is absent or errored, signalling the
            caller to apply the degraded-mode limiter.
        """
        if self.redis is None:
            return None, False
        try:
            raw = await self.redis.get(self.cache_key(token_hash))
        except Exception:
            logger.warning("Failed to read binding cache.", extra={"token_hash": token_hash})
            return None, False
        if raw is None:
            return None, True

        data = json.loads(raw)
        return (
            # The cache stores only bound_ip and status; id/token_display are
            # placeholders and ip_matched is computed by the caller.
            BindingRecord(
                id=0,
                token_hash=token_hash,
                token_display="",
                bound_ip=data["bound_ip"],
                status=int(data["status"]),
                ip_matched=False,
            ),
            True,
        )

    async def _touch_cache(self, token_hash: str) -> None:
        """Extend the TTL of a cached binding; best-effort."""
        if self.redis is None:
            return
        try:
            await self.redis.expire(self.cache_key(token_hash), self.settings.redis_binding_ttl_seconds)
        except Exception:
            logger.warning("Failed to extend binding cache TTL.", extra={"token_hash": token_hash})

    async def _load_binding_from_db(self, token_hash: str, client_ip: str) -> BindingRecord | None:
        """Fetch a binding row and compute IP containment inside PostgreSQL.

        Raises:
            SQLAlchemyError: propagated to the caller for failsafe handling.
        """
        # BUGFIX: use <<= (contained-by-or-equals) rather than the strict <<
        # operator. First-use binds store the bare client IP cast to cidr (a /32
        # network), and `inet << cidr` is FALSE when the address equals the
        # network — so exact /32 bindings never matched on this fallback path.
        query = text(
            """
            SELECT
                id,
                token_hash,
                token_display,
                bound_ip::text AS bound_ip,
                status,
                CAST(:client_ip AS inet) <<= bound_ip AS ip_matched
            FROM token_bindings
            WHERE token_hash = :token_hash
            LIMIT 1
            """
        )
        async with self.session_factory() as session:
            result = await session.execute(query, {"token_hash": token_hash, "client_ip": client_ip})
            row = result.mappings().first()
            if row is None:
                return None
            return BindingRecord(
                id=int(row["id"]),
                token_hash=str(row["token_hash"]),
                token_display=str(row["token_display"]),
                bound_ip=str(row["bound_ip"]),
                status=int(row["status"]),
                ip_matched=bool(row["ip_matched"]),
            )

    async def _create_binding(self, token_hash: str, token_display: str, client_ip: str) -> BindingRecord | None:
        """Insert a first-use binding for ``client_ip``.

        Returns:
            The created record (ip_matched is trivially True), or None when a
            concurrent request won the insert race (ON CONFLICT DO NOTHING
            yields no RETURNING row).

        Raises:
            SQLAlchemyError: after rolling back the session.
        """
        statement = text(
            """
            INSERT INTO token_bindings (token_hash, token_display, bound_ip, status)
            VALUES (:token_hash, :token_display, CAST(:bound_ip AS cidr), :status)
            ON CONFLICT (token_hash) DO NOTHING
            RETURNING id, token_hash, token_display, bound_ip::text AS bound_ip, status
            """
        )
        async with self.session_factory() as session:
            try:
                result = await session.execute(
                    statement,
                    {
                        "token_hash": token_hash,
                        "token_display": token_display,
                        "bound_ip": client_ip,
                        "status": STATUS_ACTIVE,
                    },
                )
                row = result.mappings().first()
                await session.commit()
            except SQLAlchemyError:
                await session.rollback()
                raise
            if row is None:
                return None
            return BindingRecord(
                id=int(row["id"]),
                token_hash=str(row["token_hash"]),
                token_display=str(row["token_display"]),
                bound_ip=str(row["bound_ip"]),
                status=int(row["status"]),
                # The binding was just created from this client's IP, so it matches by construction.
                ip_matched=True,
            )

    def _handle_backend_failure(self, token_hash: str, token_display: str) -> BindingCheckResult:
        """Map a storage failure to allow/deny per the runtime failsafe mode.

        Must be called from an ``except`` block: logger.exception attaches the
        active traceback.
        """
        runtime_settings = self.runtime_settings_getter()
        logger.exception(
            "Binding storage backend failed.",
            extra={"failsafe_mode": runtime_settings.failsafe_mode, "token_hash": token_hash},
        )
        if runtime_settings.failsafe_mode == "open":
            # Fail-open: availability over enforcement.
            return BindingCheckResult(
                allowed=True,
                status_code=200,
                detail="Allowed by failsafe mode.",
                token_hash=token_hash,
                token_display=token_display,
            )
        return BindingCheckResult(
            allowed=False,
            status_code=503,
            detail="Binding backend unavailable and failsafe mode is closed.",
            token_hash=token_hash,
            token_display=token_display,
        )

    async def _flush_loop(self) -> None:
        """Background loop: periodically flush queued last_used_at updates.

        Cancellation is re-raised so ``stop()`` observes a clean CancelledError.
        """
        try:
            while not self._stop_event.is_set():
                await asyncio.sleep(self.settings.last_used_flush_interval_seconds)
                await self.flush_last_used_updates()
        except asyncio.CancelledError:
            raise