463 lines
17 KiB
Python
463 lines
17 KiB
Python
from __future__ import annotations

import asyncio
import json
import logging
import time
from dataclasses import dataclass
from datetime import date, timedelta
from typing import Callable

from redis.asyncio import Redis
from sqlalchemy import func, select, update
from sqlalchemy.exc import IntegrityError, SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker

from app.config import RuntimeSettings, Settings
from app.core.ip_utils import is_ip_in_network
from app.core.security import hash_token, mask_token
from app.models.token_binding import (
    BINDING_MODE_ALL,
    BINDING_MODE_MULTIPLE,
    BINDING_MODE_SINGLE,
    STATUS_ACTIVE,
    STATUS_BANNED,
    TokenBinding,
)
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
@dataclass(slots=True)
class BindingRecord:
    """Snapshot of one token binding, evaluated against a specific client IP.

    Built either from a ``TokenBinding`` ORM row (``to_binding_record``) or
    from the Redis cache payload (``_load_binding_from_cache``); in the cache
    case ``id`` is 0 and ``token_display`` is empty, since neither is cached.
    """

    id: int  # TokenBinding primary key; 0 when reconstructed from cache
    token_hash: str  # HMAC hash of the raw token (the lookup key)
    token_display: str  # masked token for logs/UI; "" when from cache
    bound_ip: str  # human-readable bound-IP display value
    binding_mode: str  # BINDING_MODE_SINGLE / _MULTIPLE / _ALL
    allowed_ips: list[str]  # IPs/networks the token may be used from
    status: int  # STATUS_ACTIVE or STATUS_BANNED
    ip_matched: bool  # precomputed: did the client IP satisfy the rule?
|
|
|
|
|
|
@dataclass(slots=True)
class BindingCheckResult:
    """Outcome of a token-binding check, ready to map onto an HTTP response.

    ``status_code`` follows HTTP semantics; this module produces 200, 403,
    429 and 503.
    """

    allowed: bool  # True when the request may proceed
    status_code: int  # HTTP-style status code for the caller to return
    detail: str  # human-readable explanation of the decision
    token_hash: str | None = None  # HMAC hash of the token, when known
    token_display: str | None = None  # masked token for logs, when known
    bound_ip: str | None = None  # bound-IP display of the matched binding
    should_alert: bool = False  # True when a denial should raise an alert
    newly_bound: bool = False  # True when this request created the binding
|
|
|
|
|
|
class InMemoryRateLimiter:
    """Coarse in-process limiter: at most N calls per one-second window.

    The window is the integer part of the monotonic clock; the counter is
    reset whenever the window rolls over. Used only while Redis is degraded.
    """

    def __init__(self, max_per_second: int) -> None:
        # Clamp to >= 1 so the limiter can never reject everything outright.
        self.max_per_second = max(1, max_per_second)
        self._window = int(time.monotonic())
        self._count = 0
        self._lock = asyncio.Lock()

    async def allow(self) -> bool:
        """Return True when another call fits into the current window."""
        async with self._lock:
            now_window = int(time.monotonic())
            if now_window != self._window:
                # A new second started: begin a fresh count.
                self._window = now_window
                self._count = 0
            if self._count < self.max_per_second:
                self._count += 1
                return True
            return False
|
|
|
|
|
|
class BindingService:
|
|
    def __init__(
        self,
        settings: Settings,
        session_factory: async_sessionmaker[AsyncSession],
        redis: Redis | None,
        runtime_settings_getter: Callable[[], RuntimeSettings],
    ) -> None:
        """Wire up storage backends and the last-used write-behind queue.

        ``redis`` may be None; cache operations then become no-ops and every
        lookup falls through to PostgreSQL.
        ``runtime_settings_getter`` is re-invoked on each backend failure so
        failsafe-mode changes take effect without a restart.
        """
        self.settings = settings
        self.session_factory = session_factory
        self.redis = redis
        self.runtime_settings_getter = runtime_settings_getter
        # Bounded queue of token hashes awaiting a last_used_at flush;
        # overflow drops updates instead of blocking the request path.
        self.last_used_queue: asyncio.Queue[str] = asyncio.Queue(maxsize=settings.last_used_queue_size)
        self._flush_task: asyncio.Task[None] | None = None
        self._stop_event = asyncio.Event()
        # Degraded-mode limiter used when Redis is unreachable; sized to
        # half the downstream connection budget.
        self._redis_degraded_limiter = InMemoryRateLimiter(settings.downstream_max_connections // 2)
|
|
|
|
async def start(self) -> None:
|
|
if self._flush_task is None:
|
|
self._stop_event.clear()
|
|
self._flush_task = asyncio.create_task(self._flush_loop(), name="binding-last-used-flush")
|
|
|
|
async def stop(self) -> None:
|
|
self._stop_event.set()
|
|
if self._flush_task is not None:
|
|
self._flush_task.cancel()
|
|
try:
|
|
await self._flush_task
|
|
except asyncio.CancelledError:
|
|
pass
|
|
self._flush_task = None
|
|
await self.flush_last_used_updates()
|
|
|
|
def status_label(self, status_code: int) -> str:
|
|
return "Active" if status_code == STATUS_ACTIVE else "Banned"
|
|
|
|
def cache_key(self, token_hash: str) -> str:
|
|
return f"sentinel:binding:{token_hash}"
|
|
|
|
def metrics_key(self, target_date: date) -> str:
|
|
return f"sentinel:metrics:{target_date.isoformat()}"
|
|
|
|
def build_bound_ip_display(self, binding_mode: str, allowed_ips: list[str]) -> str:
|
|
if binding_mode == BINDING_MODE_ALL:
|
|
return "ALL"
|
|
if not allowed_ips:
|
|
return "-"
|
|
if binding_mode == BINDING_MODE_MULTIPLE:
|
|
return ", ".join(allowed_ips)
|
|
return allowed_ips[0]
|
|
|
|
def is_client_allowed(self, client_ip: str, binding_mode: str, allowed_ips: list[str]) -> bool:
|
|
if binding_mode == BINDING_MODE_ALL:
|
|
return True
|
|
return any(is_ip_in_network(client_ip, item) for item in allowed_ips)
|
|
|
|
def to_binding_record(self, binding: TokenBinding, client_ip: str) -> BindingRecord:
|
|
allowed_ips = [str(item) for item in binding.allowed_ips]
|
|
binding_mode = binding.binding_mode or BINDING_MODE_SINGLE
|
|
return BindingRecord(
|
|
id=binding.id,
|
|
token_hash=binding.token_hash,
|
|
token_display=binding.token_display,
|
|
bound_ip=binding.bound_ip,
|
|
binding_mode=binding_mode,
|
|
allowed_ips=allowed_ips,
|
|
status=binding.status,
|
|
ip_matched=self.is_client_allowed(client_ip, binding_mode, allowed_ips),
|
|
)
|
|
|
|
    def denied_result(
        self,
        token_hash: str,
        token_display: str,
        bound_ip: str,
        detail: str,
        *,
        should_alert: bool = True,
        status_code: int = 403,
    ) -> BindingCheckResult:
        """Build a denial result; alerts by default, HTTP 403 by default."""
        return BindingCheckResult(
            allowed=False,
            status_code=status_code,
            detail=detail,
            token_hash=token_hash,
            token_display=token_display,
            bound_ip=bound_ip,
            should_alert=should_alert,
        )
|
|
|
|
    def allowed_result(
        self,
        token_hash: str,
        token_display: str,
        bound_ip: str,
        detail: str,
        *,
        newly_bound: bool = False,
    ) -> BindingCheckResult:
        """Build a 200 "allowed" result; ``newly_bound`` marks first-use binds."""
        return BindingCheckResult(
            allowed=True,
            status_code=200,
            detail=detail,
            token_hash=token_hash,
            token_display=token_display,
            bound_ip=bound_ip,
            newly_bound=newly_bound,
        )
|
|
|
|
def evaluate_existing_record(
|
|
self,
|
|
record: BindingRecord,
|
|
token_hash: str,
|
|
token_display: str,
|
|
detail: str,
|
|
) -> BindingCheckResult:
|
|
if record.status == STATUS_BANNED:
|
|
return self.denied_result(token_hash, token_display, record.bound_ip, "Token is banned.")
|
|
if record.ip_matched:
|
|
self.record_last_used(token_hash)
|
|
return self.allowed_result(token_hash, token_display, record.bound_ip, detail)
|
|
return self.denied_result(
|
|
token_hash,
|
|
token_display,
|
|
record.bound_ip,
|
|
"Client IP does not match the allowed binding rule.",
|
|
)
|
|
|
|
    async def evaluate_token_binding(self, token: str, client_ip: str) -> BindingCheckResult:
        """Decide whether ``token`` may be used from ``client_ip``.

        Lookup order: Redis cache, then PostgreSQL, then first-use bind
        creation. When Redis is unreachable, the PostgreSQL fallback is
        throttled by the in-memory degraded-mode limiter (429 on overflow);
        database failures defer to the failsafe policy
        (see ``_handle_backend_failure``).
        """
        token_hash = hash_token(token, self.settings.sentinel_hmac_secret)
        token_display = mask_token(token)

        cache_hit, cache_available = await self._load_binding_from_cache(token_hash, client_ip)
        if cache_hit is not None:
            if cache_hit.ip_matched:
                # Sliding TTL: only refreshed for matching (hot, legitimate) hits.
                await self._touch_cache(token_hash)
            return self.evaluate_existing_record(cache_hit, token_hash, token_display, "Allowed from cache.")

        if not cache_available:
            logger.warning("Redis is unavailable. Falling back to PostgreSQL for token binding.")
            if not await self._redis_degraded_limiter.allow():
                logger.warning("Redis degraded limiter rejected a request during PostgreSQL fallback.")
                return BindingCheckResult(
                    allowed=False,
                    status_code=429,
                    detail="Redis degraded mode rate limit reached.",
                    token_hash=token_hash,
                    token_display=token_display,
                )

        try:
            record = await self._load_binding_from_db(token_hash, client_ip)
        except SQLAlchemyError:
            return self._handle_backend_failure(token_hash, token_display)

        if record is not None:
            # Re-populate the cache so the next request short-circuits.
            await self.sync_binding_cache(record.token_hash, record.bound_ip, record.binding_mode, record.allowed_ips, record.status)
            return self.evaluate_existing_record(record, token_hash, token_display, "Allowed from PostgreSQL.")

        try:
            created = await self._create_binding(token_hash, token_display, client_ip)
        except SQLAlchemyError:
            return self._handle_backend_failure(token_hash, token_display)

        if created is None:
            # _create_binding returns None on a duplicate-key race: another
            # worker bound this token concurrently. Re-read the winner's row.
            try:
                existing = await self._load_binding_from_db(token_hash, client_ip)
            except SQLAlchemyError:
                return self._handle_backend_failure(token_hash, token_display)
            if existing is None:
                # Insert collided but the row is gone again; treat as a
                # backend failure and apply the failsafe policy.
                return self._handle_backend_failure(token_hash, token_display)
            await self.sync_binding_cache(
                existing.token_hash,
                existing.bound_ip,
                existing.binding_mode,
                existing.allowed_ips,
                existing.status,
            )
            return self.evaluate_existing_record(existing, token_hash, token_display, "Allowed after concurrent bind resolution.")

        await self.sync_binding_cache(
            created.token_hash,
            created.bound_ip,
            created.binding_mode,
            created.allowed_ips,
            created.status,
        )
        return self.allowed_result(token_hash, token_display, created.bound_ip, "First-use bind created.", newly_bound=True)
|
|
|
|
async def sync_binding_cache(
|
|
self,
|
|
token_hash: str,
|
|
bound_ip: str,
|
|
binding_mode: str,
|
|
allowed_ips: list[str],
|
|
status_code: int,
|
|
) -> None:
|
|
if self.redis is None:
|
|
return
|
|
payload = json.dumps(
|
|
{
|
|
"bound_ip": bound_ip,
|
|
"binding_mode": binding_mode,
|
|
"allowed_ips": allowed_ips,
|
|
"status": status_code,
|
|
}
|
|
)
|
|
try:
|
|
await self.redis.set(self.cache_key(token_hash), payload, ex=self.settings.redis_binding_ttl_seconds)
|
|
except Exception:
|
|
logger.warning("Failed to write binding cache.", extra={"token_hash": token_hash})
|
|
|
|
async def invalidate_binding_cache(self, token_hash: str) -> None:
|
|
if self.redis is None:
|
|
return
|
|
try:
|
|
await self.redis.delete(self.cache_key(token_hash))
|
|
except Exception:
|
|
logger.warning("Failed to delete binding cache.", extra={"token_hash": token_hash})
|
|
|
|
async def invalidate_many(self, token_hashes: list[str]) -> None:
|
|
if self.redis is None or not token_hashes:
|
|
return
|
|
keys = [self.cache_key(item) for item in token_hashes]
|
|
try:
|
|
await self.redis.delete(*keys)
|
|
except Exception:
|
|
logger.warning("Failed to delete multiple binding cache keys.", extra={"count": len(keys)})
|
|
|
|
    def record_last_used(self, token_hash: str) -> None:
        """Queue a last_used_at update without blocking the request path.

        When the bounded queue is full the update is dropped (and logged)
        rather than stalling request handling; the flush loop drains it.
        """
        try:
            self.last_used_queue.put_nowait(token_hash)
        except asyncio.QueueFull:
            logger.warning("last_used queue is full; dropping update.", extra={"token_hash": token_hash})
|
|
|
|
    async def flush_last_used_updates(self) -> None:
        """Drain the queue and bulk-update last_used_at for drained tokens.

        Hashes are collected into a set to de-duplicate repeat users, then
        written with a single UPDATE. On database failure the transaction is
        rolled back and the drained hashes are dropped (best effort — the
        timestamps are advisory).
        """
        token_hashes: set[str] = set()
        while True:
            try:
                token_hashes.add(self.last_used_queue.get_nowait())
            except asyncio.QueueEmpty:
                break

        if not token_hashes:
            return

        async with self.session_factory() as session:
            try:
                stmt = (
                    update(TokenBinding)
                    .where(TokenBinding.token_hash.in_(token_hashes))
                    .values(last_used_at=func.now())
                )
                await session.execute(stmt)
                await session.commit()
            except SQLAlchemyError:
                await session.rollback()
                logger.exception("Failed to flush last_used_at updates.", extra={"count": len(token_hashes)})
|
|
|
|
    async def increment_request_metric(self, outcome: str | None) -> None:
        """Bump today's request counters in Redis (best effort).

        Always increments ``total``; additionally increments the ``outcome``
        field when it is one of the tracked values, and refreshes the key's
        TTL on every call. No-op when Redis is not configured.
        """
        if self.redis is None:
            return
        key = self.metrics_key(date.today())
        ttl = self.settings.metrics_ttl_days * 86400
        try:
            # Pipeline calls are buffered client-side and sent on execute().
            async with self.redis.pipeline(transaction=True) as pipeline:
                pipeline.hincrby(key, "total", 1)
                if outcome in {"allowed", "intercepted"}:
                    pipeline.hincrby(key, outcome, 1)
                pipeline.expire(key, ttl)
                await pipeline.execute()
        except Exception:
            logger.warning("Failed to increment request metrics.", extra={"outcome": outcome})
|
|
|
|
async def get_metrics_window(self, days: int = 7) -> list[dict[str, int | str]]:
|
|
if self.redis is None:
|
|
return [
|
|
{"date": (date.today() - timedelta(days=offset)).isoformat(), "allowed": 0, "intercepted": 0, "total": 0}
|
|
for offset in range(days - 1, -1, -1)
|
|
]
|
|
|
|
series: list[dict[str, int | str]] = []
|
|
for offset in range(days - 1, -1, -1):
|
|
target = date.today() - timedelta(days=offset)
|
|
raw = await self.redis.hgetall(self.metrics_key(target))
|
|
series.append(
|
|
{
|
|
"date": target.isoformat(),
|
|
"allowed": int(raw.get("allowed", 0)),
|
|
"intercepted": int(raw.get("intercepted", 0)),
|
|
"total": int(raw.get("total", 0)),
|
|
}
|
|
)
|
|
return series
|
|
|
|
async def _load_binding_from_cache(self, token_hash: str, client_ip: str) -> tuple[BindingRecord | None, bool]:
|
|
if self.redis is None:
|
|
return None, False
|
|
try:
|
|
raw = await self.redis.get(self.cache_key(token_hash))
|
|
except Exception:
|
|
logger.warning("Failed to read binding cache.", extra={"token_hash": token_hash})
|
|
return None, False
|
|
if raw is None:
|
|
return None, True
|
|
|
|
data = json.loads(raw)
|
|
allowed_ips = [str(item) for item in data.get("allowed_ips", [])]
|
|
binding_mode = str(data.get("binding_mode", BINDING_MODE_SINGLE))
|
|
return (
|
|
BindingRecord(
|
|
id=0,
|
|
token_hash=token_hash,
|
|
token_display="",
|
|
bound_ip=str(data.get("bound_ip", self.build_bound_ip_display(binding_mode, allowed_ips))),
|
|
binding_mode=binding_mode,
|
|
allowed_ips=allowed_ips,
|
|
status=int(data["status"]),
|
|
ip_matched=self.is_client_allowed(client_ip, binding_mode, allowed_ips),
|
|
),
|
|
True,
|
|
)
|
|
|
|
async def _touch_cache(self, token_hash: str) -> None:
|
|
if self.redis is None:
|
|
return
|
|
try:
|
|
await self.redis.expire(self.cache_key(token_hash), self.settings.redis_binding_ttl_seconds)
|
|
except Exception:
|
|
logger.warning("Failed to extend binding cache TTL.", extra={"token_hash": token_hash})
|
|
|
|
async def _load_binding_from_db(self, token_hash: str, client_ip: str) -> BindingRecord | None:
|
|
async with self.session_factory() as session:
|
|
binding = await session.scalar(select(TokenBinding).where(TokenBinding.token_hash == token_hash).limit(1))
|
|
if binding is None:
|
|
return None
|
|
return self.to_binding_record(binding, client_ip)
|
|
|
|
async def _create_binding(self, token_hash: str, token_display: str, client_ip: str) -> BindingRecord | None:
|
|
async with self.session_factory() as session:
|
|
try:
|
|
binding = TokenBinding(
|
|
token_hash=token_hash,
|
|
token_display=token_display,
|
|
bound_ip=client_ip,
|
|
binding_mode=BINDING_MODE_SINGLE,
|
|
allowed_ips=[client_ip],
|
|
status=STATUS_ACTIVE,
|
|
)
|
|
session.add(binding)
|
|
await session.flush()
|
|
await session.commit()
|
|
await session.refresh(binding)
|
|
except SQLAlchemyError as exc:
|
|
await session.rollback()
|
|
if "duplicate key" in str(exc).lower() or "unique" in str(exc).lower():
|
|
return None
|
|
raise
|
|
return self.to_binding_record(binding, client_ip)
|
|
|
|
def _handle_backend_failure(self, token_hash: str, token_display: str) -> BindingCheckResult:
|
|
runtime_settings = self.runtime_settings_getter()
|
|
logger.exception(
|
|
"Binding storage backend failed.",
|
|
extra={"failsafe_mode": runtime_settings.failsafe_mode, "token_hash": token_hash},
|
|
)
|
|
if runtime_settings.failsafe_mode == "open":
|
|
return BindingCheckResult(
|
|
allowed=True,
|
|
status_code=200,
|
|
detail="Allowed by failsafe mode.",
|
|
token_hash=token_hash,
|
|
token_display=token_display,
|
|
)
|
|
return BindingCheckResult(
|
|
allowed=False,
|
|
status_code=503,
|
|
detail="Binding backend unavailable and failsafe mode is closed.",
|
|
token_hash=token_hash,
|
|
token_display=token_display,
|
|
)
|
|
|
|
async def _flush_loop(self) -> None:
|
|
try:
|
|
while not self._stop_event.is_set():
|
|
await asyncio.sleep(self.settings.last_used_flush_interval_seconds)
|
|
await self.flush_last_used_updates()
|
|
except asyncio.CancelledError:
|
|
raise
|