from __future__ import annotations

import asyncio
import json
import logging
import time
from dataclasses import dataclass
from datetime import UTC, date, timedelta
from typing import Callable

from redis.asyncio import Redis
from sqlalchemy import func, select, text, update
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker

from app.config import RuntimeSettings, Settings
from app.core.ip_utils import is_ip_in_network
from app.core.security import hash_token, mask_token
from app.models.token_binding import STATUS_ACTIVE, STATUS_BANNED, TokenBinding

logger = logging.getLogger(__name__)


@dataclass(slots=True)
class BindingRecord:
    """A token binding row, as materialized from the Redis cache or PostgreSQL.

    Cache-sourced records carry placeholder ``id``/``token_display`` values and
    ``ip_matched=False``; the caller re-checks the IP with ``is_ip_in_network``.
    """

    id: int
    token_hash: str
    token_display: str
    bound_ip: str  # CIDR text, e.g. "203.0.113.7/32"
    status: int  # STATUS_ACTIVE or STATUS_BANNED
    ip_matched: bool  # DB-computed containment of the client IP in bound_ip


@dataclass(slots=True)
class BindingCheckResult:
    """Outcome of evaluating one (token, client IP) pair against its binding."""

    allowed: bool
    status_code: int  # HTTP status the caller should return (200/403/429/503)
    detail: str
    token_hash: str | None = None
    token_display: str | None = None
    bound_ip: str | None = None
    should_alert: bool = False  # True on ban hits and CIDR mismatches
    newly_bound: bool = False  # True only when this request created the binding


class InMemoryRateLimiter:
    """Fixed one-second-window rate limiter used while Redis is degraded.

    Purely in-process: each worker enforces its own budget, which is acceptable
    as a blunt brake on the PostgreSQL fallback path.
    """

    def __init__(self, max_per_second: int) -> None:
        self.max_per_second = max(1, max_per_second)
        self._window = int(time.monotonic())
        self._count = 0
        self._lock = asyncio.Lock()

    async def allow(self) -> bool:
        """Return True if one more request fits in the current 1-second window."""
        async with self._lock:
            current_window = int(time.monotonic())
            if current_window != self._window:
                # Entered a new second: reset the window counter.
                self._window = current_window
                self._count = 0
            if self._count >= self.max_per_second:
                return False
            self._count += 1
            return True


class BindingService:
    """Token/IP binding enforcement with a Redis cache over PostgreSQL.

    Resolution order per request: Redis cache -> PostgreSQL row -> first-use
    INSERT. ``last_used_at`` updates are batched through an in-memory queue and
    flushed periodically by a background task.
    """

    def __init__(
        self,
        settings: Settings,
        session_factory: async_sessionmaker[AsyncSession],
        redis: Redis | None,
        runtime_settings_getter: Callable[[], RuntimeSettings],
    ) -> None:
        self.settings = settings
        self.session_factory = session_factory
        self.redis = redis
        # Re-read each time so runtime toggles (e.g. failsafe_mode) apply live.
        self.runtime_settings_getter = runtime_settings_getter
        self.last_used_queue: asyncio.Queue[str] = asyncio.Queue(maxsize=settings.last_used_queue_size)
        self._flush_task: asyncio.Task[None] | None = None
        self._stop_event = asyncio.Event()
        # Throttle DB fallback to half the downstream connection budget when
        # Redis is down, so the pool is not exhausted by binding lookups.
        self._redis_degraded_limiter = InMemoryRateLimiter(settings.downstream_max_connections // 2)

    async def start(self) -> None:
        """Start the background last_used flush task (idempotent)."""
        if self._flush_task is None:
            self._stop_event.clear()
            self._flush_task = asyncio.create_task(self._flush_loop(), name="binding-last-used-flush")

    async def stop(self) -> None:
        """Stop the flush task and drain any queued last_used updates."""
        self._stop_event.set()
        if self._flush_task is not None:
            self._flush_task.cancel()
            try:
                await self._flush_task
            except asyncio.CancelledError:
                pass
            self._flush_task = None
        # Final drain so updates queued after the last periodic flush persist.
        await self.flush_last_used_updates()

    def status_label(self, status_code: int) -> str:
        """Return a human-readable label for a binding status code."""
        return "Active" if status_code == STATUS_ACTIVE else "Banned"

    def cache_key(self, token_hash: str) -> str:
        """Redis key holding the cached binding for *token_hash*."""
        return f"sentinel:binding:{token_hash}"

    def metrics_key(self, target_date: date) -> str:
        """Redis hash key holding per-day request metrics."""
        return f"sentinel:metrics:{target_date.isoformat()}"

    async def evaluate_token_binding(self, token: str, client_ip: str) -> BindingCheckResult:
        """Check *token* against its IP binding, creating a first-use bind if absent.

        Tries the Redis cache first; on a miss (or Redis outage, rate-limited)
        falls back to PostgreSQL, and finally attempts a first-use INSERT with a
        concurrent-insert race resolution path.
        """
        token_hash = hash_token(token, self.settings.sentinel_hmac_secret)
        token_display = mask_token(token)

        cache_hit, cache_available = await self._load_binding_from_cache(token_hash)
        if cache_hit is not None:
            if cache_hit.status == STATUS_BANNED:
                return BindingCheckResult(
                    allowed=False,
                    status_code=403,
                    detail="Token is banned.",
                    token_hash=token_hash,
                    token_display=token_display,
                    bound_ip=cache_hit.bound_ip,
                    should_alert=True,
                )
            if is_ip_in_network(client_ip, cache_hit.bound_ip):
                # Sliding TTL: extend the cache entry on every successful hit.
                await self._touch_cache(token_hash)
                self.record_last_used(token_hash)
                return BindingCheckResult(
                    allowed=True,
                    status_code=200,
                    detail="Allowed from cache.",
                    token_hash=token_hash,
                    token_display=token_display,
                    bound_ip=cache_hit.bound_ip,
                )
            return BindingCheckResult(
                allowed=False,
                status_code=403,
                detail="Client IP does not match the bound CIDR.",
                token_hash=token_hash,
                token_display=token_display,
                bound_ip=cache_hit.bound_ip,
                should_alert=True,
            )

        if not cache_available:
            logger.warning("Redis is unavailable. Falling back to PostgreSQL for token binding.")
            if not await self._redis_degraded_limiter.allow():
                logger.warning("Redis degraded limiter rejected a request during PostgreSQL fallback.")
                return BindingCheckResult(
                    allowed=False,
                    status_code=429,
                    detail="Redis degraded mode rate limit reached.",
                    token_hash=token_hash,
                    token_display=token_display,
                )

        try:
            record = await self._load_binding_from_db(token_hash, client_ip)
        except SQLAlchemyError:
            return self._handle_backend_failure(token_hash, token_display)
        if record is not None:
            await self.sync_binding_cache(record.token_hash, record.bound_ip, record.status)
            return self._result_for_record(record, token_hash, token_display, "Allowed from PostgreSQL.")

        try:
            created = await self._create_binding(token_hash, token_display, client_ip)
        except SQLAlchemyError:
            return self._handle_backend_failure(token_hash, token_display)
        if created is None:
            # Lost a first-use race: a concurrent request inserted the row
            # (ON CONFLICT DO NOTHING returned no row). Re-read and evaluate.
            try:
                existing = await self._load_binding_from_db(token_hash, client_ip)
            except SQLAlchemyError:
                return self._handle_backend_failure(token_hash, token_display)
            if existing is None:
                # Insert conflicted but the row is gone: inconsistent backend.
                return self._handle_backend_failure(token_hash, token_display)
            await self.sync_binding_cache(existing.token_hash, existing.bound_ip, existing.status)
            return self._result_for_record(
                existing, token_hash, token_display, "Allowed after concurrent bind resolution."
            )

        await self.sync_binding_cache(created.token_hash, created.bound_ip, created.status)
        return BindingCheckResult(
            allowed=True,
            status_code=200,
            detail="First-use bind created.",
            token_hash=token_hash,
            token_display=token_display,
            bound_ip=created.bound_ip,
            newly_bound=True,
        )

    def _result_for_record(
        self,
        record: BindingRecord,
        token_hash: str,
        token_display: str,
        allowed_detail: str,
    ) -> BindingCheckResult:
        """Translate a DB-loaded BindingRecord into a BindingCheckResult.

        Checks ban status first, then the DB-computed IP containment; records
        last_used only on the allowed path.
        """
        if record.status == STATUS_BANNED:
            return BindingCheckResult(
                allowed=False,
                status_code=403,
                detail="Token is banned.",
                token_hash=token_hash,
                token_display=token_display,
                bound_ip=record.bound_ip,
                should_alert=True,
            )
        if record.ip_matched:
            self.record_last_used(token_hash)
            return BindingCheckResult(
                allowed=True,
                status_code=200,
                detail=allowed_detail,
                token_hash=token_hash,
                token_display=token_display,
                bound_ip=record.bound_ip,
            )
        return BindingCheckResult(
            allowed=False,
            status_code=403,
            detail="Client IP does not match the bound CIDR.",
            token_hash=token_hash,
            token_display=token_display,
            bound_ip=record.bound_ip,
            should_alert=True,
        )

    async def sync_binding_cache(self, token_hash: str, bound_ip: str, status_code: int) -> None:
        """Best-effort write of the binding into the Redis cache."""
        if self.redis is None:
            return
        payload = json.dumps({"bound_ip": bound_ip, "status": status_code})
        try:
            await self.redis.set(self.cache_key(token_hash), payload, ex=self.settings.redis_binding_ttl_seconds)
        except Exception:
            logger.warning("Failed to write binding cache.", extra={"token_hash": token_hash})

    async def invalidate_binding_cache(self, token_hash: str) -> None:
        """Best-effort removal of one cached binding."""
        if self.redis is None:
            return
        try:
            await self.redis.delete(self.cache_key(token_hash))
        except Exception:
            logger.warning("Failed to delete binding cache.", extra={"token_hash": token_hash})

    async def invalidate_many(self, token_hashes: list[str]) -> None:
        """Best-effort bulk removal of cached bindings."""
        if self.redis is None or not token_hashes:
            return
        keys = [self.cache_key(item) for item in token_hashes]
        try:
            await self.redis.delete(*keys)
        except Exception:
            logger.warning("Failed to delete multiple binding cache keys.", extra={"count": len(keys)})

    def record_last_used(self, token_hash: str) -> None:
        """Queue a last_used_at touch; drops (with a warning) when the queue is full."""
        try:
            self.last_used_queue.put_nowait(token_hash)
        except asyncio.QueueFull:
            logger.warning("last_used queue is full; dropping update.", extra={"token_hash": token_hash})

    async def flush_last_used_updates(self) -> None:
        """Drain the queue and persist last_used_at for all queued tokens in one UPDATE."""
        token_hashes: set[str] = set()  # set dedupes repeated touches of one token
        while True:
            try:
                token_hashes.add(self.last_used_queue.get_nowait())
            except asyncio.QueueEmpty:
                break
        if not token_hashes:
            return
        async with self.session_factory() as session:
            try:
                stmt = (
                    update(TokenBinding)
                    .where(TokenBinding.token_hash.in_(token_hashes))
                    .values(last_used_at=func.now())
                )
                await session.execute(stmt)
                await session.commit()
            except SQLAlchemyError:
                await session.rollback()
                logger.exception("Failed to flush last_used_at updates.", extra={"count": len(token_hashes)})

    async def increment_request_metric(self, outcome: str | None) -> None:
        """Best-effort bump of today's request counters in Redis.

        NOTE(review): keyed by local ``date.today()`` although ``UTC`` is
        imported at the top of the file — confirm which calendar is intended.
        """
        if self.redis is None:
            return
        key = self.metrics_key(date.today())
        ttl = self.settings.metrics_ttl_days * 86400
        try:
            async with self.redis.pipeline(transaction=True) as pipeline:
                pipeline.hincrby(key, "total", 1)
                if outcome in {"allowed", "intercepted"}:
                    pipeline.hincrby(key, outcome, 1)
                # Refresh the TTL on every increment so active days persist.
                pipeline.expire(key, ttl)
                await pipeline.execute()
        except Exception:
            logger.warning("Failed to increment request metrics.", extra={"outcome": outcome})

    async def get_metrics_window(self, days: int = 7) -> list[dict[str, int | str]]:
        """Return per-day counters for the trailing *days* window, oldest first.

        Missing days (and Redis read failures, handled best-effort like every
        other Redis call here) yield all-zero rows instead of raising.
        """
        if self.redis is None:
            return [
                {"date": (date.today() - timedelta(days=offset)).isoformat(), "allowed": 0, "intercepted": 0, "total": 0}
                for offset in range(days - 1, -1, -1)
            ]
        series: list[dict[str, int | str]] = []
        for offset in range(days - 1, -1, -1):
            target = date.today() - timedelta(days=offset)
            try:
                raw = await self.redis.hgetall(self.metrics_key(target))
            except Exception:
                # Consistent with the rest of the service: degrade, don't 500.
                logger.warning("Failed to read request metrics.", extra={"date": target.isoformat()})
                raw = {}
            # NOTE(review): assumes the Redis client decodes responses to str
            # (decode_responses=True); with bytes keys every .get would miss.
            series.append(
                {
                    "date": target.isoformat(),
                    "allowed": int(raw.get("allowed", 0)),
                    "intercepted": int(raw.get("intercepted", 0)),
                    "total": int(raw.get("total", 0)),
                }
            )
        return series

    async def _load_binding_from_cache(self, token_hash: str) -> tuple[BindingRecord | None, bool]:
        """Read the cached binding.

        Returns ``(record, cache_available)``: the second element is False only
        when Redis is absent or the read itself failed, so callers can tell a
        true miss from a degraded cache.
        """
        if self.redis is None:
            return None, False
        try:
            raw = await self.redis.get(self.cache_key(token_hash))
        except Exception:
            logger.warning("Failed to read binding cache.", extra={"token_hash": token_hash})
            return None, False
        if raw is None:
            return None, True
        try:
            data = json.loads(raw)
            record = BindingRecord(
                id=0,
                token_hash=token_hash,
                token_display="",
                bound_ip=data["bound_ip"],
                status=int(data["status"]),
                ip_matched=False,
            )
        except (ValueError, KeyError, TypeError):
            # A corrupt cache entry must not take down the request path;
            # treat it as a miss so the DB becomes authoritative again.
            logger.warning("Corrupt binding cache payload; treating as a miss.", extra={"token_hash": token_hash})
            return None, True
        return record, True

    async def _touch_cache(self, token_hash: str) -> None:
        """Best-effort TTL extension for a cached binding."""
        if self.redis is None:
            return
        try:
            await self.redis.expire(self.cache_key(token_hash), self.settings.redis_binding_ttl_seconds)
        except Exception:
            logger.warning("Failed to extend binding cache TTL.", extra={"token_hash": token_hash})

    async def _load_binding_from_db(self, token_hash: str, client_ip: str) -> BindingRecord | None:
        """Load the binding row and compute IP containment server-side.

        Raises SQLAlchemyError on backend failure (handled by the caller).
        """
        # FIX: use <<= (contained-by-or-equals), not << (strict containment).
        # Bindings are stored as host CIDRs (client IP cast to cidr => /32),
        # and `inet << '/32'` is FALSE for the exact host, so the original
        # operator rejected the very IP that created the binding on this path.
        query = text(
            """
            SELECT id, token_hash, token_display, bound_ip::text AS bound_ip, status,
                   CAST(:client_ip AS inet) <<= bound_ip AS ip_matched
            FROM token_bindings
            WHERE token_hash = :token_hash
            LIMIT 1
            """
        )
        async with self.session_factory() as session:
            result = await session.execute(query, {"token_hash": token_hash, "client_ip": client_ip})
            row = result.mappings().first()
        if row is None:
            return None
        return BindingRecord(
            id=int(row["id"]),
            token_hash=str(row["token_hash"]),
            token_display=str(row["token_display"]),
            bound_ip=str(row["bound_ip"]),
            status=int(row["status"]),
            ip_matched=bool(row["ip_matched"]),
        )

    async def _create_binding(self, token_hash: str, token_display: str, client_ip: str) -> BindingRecord | None:
        """First-use INSERT of a binding; returns None if a concurrent insert won.

        Raises SQLAlchemyError on backend failure (after rollback).
        """
        statement = text(
            """
            INSERT INTO token_bindings (token_hash, token_display, bound_ip, status)
            VALUES (:token_hash, :token_display, CAST(:bound_ip AS cidr), :status)
            ON CONFLICT (token_hash) DO NOTHING
            RETURNING id, token_hash, token_display, bound_ip::text AS bound_ip, status
            """
        )
        async with self.session_factory() as session:
            try:
                result = await session.execute(
                    statement,
                    {
                        "token_hash": token_hash,
                        "token_display": token_display,
                        "bound_ip": client_ip,
                        "status": STATUS_ACTIVE,
                    },
                )
                row = result.mappings().first()
                await session.commit()
            except SQLAlchemyError:
                await session.rollback()
                raise
        if row is None:
            # ON CONFLICT DO NOTHING: another request bound this token first.
            return None
        return BindingRecord(
            id=int(row["id"]),
            token_hash=str(row["token_hash"]),
            token_display=str(row["token_display"]),
            bound_ip=str(row["bound_ip"]),
            status=int(row["status"]),
            ip_matched=True,  # the inserting client's own IP trivially matches
        )

    def _handle_backend_failure(self, token_hash: str, token_display: str) -> BindingCheckResult:
        """Map a storage failure to a result per the runtime failsafe mode.

        Must be called from within an ``except`` block: it uses
        ``logger.exception`` to attach the active exception.
        """
        runtime_settings = self.runtime_settings_getter()
        logger.exception(
            "Binding storage backend failed.",
            extra={"failsafe_mode": runtime_settings.failsafe_mode, "token_hash": token_hash},
        )
        if runtime_settings.failsafe_mode == "open":
            return BindingCheckResult(
                allowed=True,
                status_code=200,
                detail="Allowed by failsafe mode.",
                token_hash=token_hash,
                token_display=token_display,
            )
        return BindingCheckResult(
            allowed=False,
            status_code=503,
            detail="Binding backend unavailable and failsafe mode is closed.",
            token_hash=token_hash,
            token_display=token_display,
        )

    async def _flush_loop(self) -> None:
        """Periodically flush last_used updates until stopped or cancelled."""
        try:
            while not self._stop_event.is_set():
                await asyncio.sleep(self.settings.last_used_flush_interval_seconds)
                await self.flush_last_used_updates()
        except asyncio.CancelledError:
            # Re-raise so stop() observes a clean cancellation.
            raise