Add multi-IP binding modes and deployment guide
This commit is contained in:
@@ -28,6 +28,8 @@ def to_binding_item(binding: TokenBinding, binding_service: BindingService) -> B
|
||||
id=binding.id,
|
||||
token_display=binding.token_display,
|
||||
bound_ip=str(binding.bound_ip),
|
||||
binding_mode=binding.binding_mode,
|
||||
allowed_ips=[str(item) for item in binding.allowed_ips],
|
||||
status=binding.status,
|
||||
status_label=binding_service.status_label(binding.status),
|
||||
first_used_at=binding.first_used_at,
|
||||
@@ -70,7 +72,13 @@ def log_admin_action(request: Request, settings: Settings, action: str, binding_
|
||||
|
||||
|
||||
async def commit_binding_cache(binding: TokenBinding, binding_service: BindingService) -> None:
|
||||
await binding_service.sync_binding_cache(binding.token_hash, str(binding.bound_ip), binding.status)
|
||||
await binding_service.sync_binding_cache(
|
||||
binding.token_hash,
|
||||
str(binding.bound_ip),
|
||||
binding.binding_mode,
|
||||
[str(item) for item in binding.allowed_ips],
|
||||
binding.status,
|
||||
)
|
||||
|
||||
|
||||
async def update_binding_status(
|
||||
@@ -138,7 +146,9 @@ async def update_bound_ip(
|
||||
binding_service: BindingService = Depends(get_binding_service),
|
||||
):
|
||||
binding = await get_binding_or_404(session, payload.id)
|
||||
binding.bound_ip = payload.bound_ip
|
||||
binding.binding_mode = payload.binding_mode
|
||||
binding.allowed_ips = payload.allowed_ips
|
||||
binding.bound_ip = binding_service.build_bound_ip_display(payload.binding_mode, payload.allowed_ips)
|
||||
await session.commit()
|
||||
await commit_binding_cache(binding, binding_service)
|
||||
log_admin_action(request, settings, "update_ip", payload.id)
|
||||
|
||||
@@ -76,7 +76,7 @@ async def build_recent_intercepts(session: AsyncSession) -> list[InterceptLogIte
|
||||
InterceptLogItem(
|
||||
id=item.id,
|
||||
token_display=item.token_display,
|
||||
bound_ip=str(item.bound_ip),
|
||||
bound_ip=item.bound_ip,
|
||||
attempt_ip=str(item.attempt_ip),
|
||||
alerted=item.alerted,
|
||||
intercepted_at=item.intercepted_at,
|
||||
|
||||
@@ -38,7 +38,7 @@ def to_log_item(item: InterceptLog) -> InterceptLogItem:
|
||||
return InterceptLogItem(
|
||||
id=item.id,
|
||||
token_display=item.token_display,
|
||||
bound_ip=str(item.bound_ip),
|
||||
bound_ip=item.bound_ip,
|
||||
attempt_ip=str(item.attempt_ip),
|
||||
alerted=item.alerted,
|
||||
intercepted_at=item.intercepted_at,
|
||||
@@ -47,13 +47,13 @@ def to_log_item(item: InterceptLog) -> InterceptLogItem:
|
||||
|
||||
def write_log_csv(buffer: io.StringIO, logs: list[InterceptLog]) -> None:
|
||||
writer = csv.writer(buffer)
|
||||
writer.writerow(["id", "token_display", "bound_ip", "attempt_ip", "alerted", "intercepted_at"])
|
||||
writer.writerow(["id", "token_display", "binding_rule", "attempt_ip", "alerted", "intercepted_at"])
|
||||
for item in logs:
|
||||
writer.writerow(
|
||||
[
|
||||
item.id,
|
||||
item.token_display,
|
||||
str(item.bound_ip),
|
||||
item.bound_ip,
|
||||
str(item.attempt_ip),
|
||||
item.alerted,
|
||||
item.intercepted_at.isoformat(),
|
||||
|
||||
@@ -14,7 +14,7 @@ from redis.asyncio import from_url as redis_from_url
|
||||
from app.api import auth, bindings, dashboard, logs, settings as settings_api
|
||||
from app.config import RUNTIME_SETTINGS_REDIS_KEY, RuntimeSettings, Settings, get_settings
|
||||
from app.models import intercept_log, token_binding # noqa: F401
|
||||
from app.models.db import close_db, get_session_factory, init_db
|
||||
from app.models.db import close_db, ensure_schema_compatibility, get_session_factory, init_db
|
||||
from app.proxy.handler import router as proxy_router
|
||||
from app.services.alert_service import AlertService
|
||||
from app.services.archive_service import ArchiveService
|
||||
@@ -100,6 +100,7 @@ async def load_runtime_settings(redis: Redis | None, settings: Settings) -> Runt
|
||||
async def lifespan(app: FastAPI):
|
||||
settings = get_settings()
|
||||
init_db(settings)
|
||||
await ensure_schema_compatibility()
|
||||
session_factory = get_session_factory()
|
||||
|
||||
redis: Redis | None = redis_from_url(
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker, create_async_engine
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
|
||||
@@ -40,6 +41,31 @@ def get_session_factory() -> async_sessionmaker[AsyncSession]:
|
||||
return _session_factory
|
||||
|
||||
|
||||
async def ensure_schema_compatibility() -> None:
|
||||
engine = get_engine()
|
||||
statements = [
|
||||
"DROP INDEX IF EXISTS idx_token_bindings_ip",
|
||||
"ALTER TABLE token_bindings ALTER COLUMN bound_ip TYPE TEXT USING bound_ip::text",
|
||||
"ALTER TABLE intercept_logs ALTER COLUMN bound_ip TYPE TEXT USING bound_ip::text",
|
||||
"ALTER TABLE token_bindings ADD COLUMN IF NOT EXISTS binding_mode VARCHAR(16) DEFAULT 'single'",
|
||||
"ALTER TABLE token_bindings ADD COLUMN IF NOT EXISTS allowed_ips JSONB DEFAULT '[]'::jsonb",
|
||||
"UPDATE token_bindings SET binding_mode = 'single' WHERE binding_mode IS NULL OR binding_mode = ''",
|
||||
"""
|
||||
UPDATE token_bindings
|
||||
SET allowed_ips = jsonb_build_array(bound_ip)
|
||||
WHERE allowed_ips IS NULL OR allowed_ips = '[]'::jsonb
|
||||
""",
|
||||
"ALTER TABLE token_bindings ALTER COLUMN binding_mode SET NOT NULL",
|
||||
"ALTER TABLE token_bindings ALTER COLUMN allowed_ips SET NOT NULL",
|
||||
"ALTER TABLE token_bindings ALTER COLUMN binding_mode SET DEFAULT 'single'",
|
||||
"ALTER TABLE token_bindings ALTER COLUMN allowed_ips SET DEFAULT '[]'::jsonb",
|
||||
"CREATE INDEX IF NOT EXISTS idx_token_bindings_ip ON token_bindings(bound_ip)",
|
||||
]
|
||||
async with engine.begin() as connection:
|
||||
for statement in statements:
|
||||
await connection.execute(text(statement))
|
||||
|
||||
|
||||
async def close_db() -> None:
|
||||
global _engine, _session_factory
|
||||
if _engine is not None:
|
||||
|
||||
@@ -2,8 +2,8 @@ from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import Boolean, DateTime, Index, String, func, text
|
||||
from sqlalchemy.dialects.postgresql import CIDR, INET
|
||||
from sqlalchemy import Boolean, DateTime, Index, String, Text, func, text
|
||||
from sqlalchemy.dialects.postgresql import INET
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from app.models.db import Base
|
||||
@@ -19,7 +19,7 @@ class InterceptLog(Base):
|
||||
id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
|
||||
token_hash: Mapped[str] = mapped_column(String(64), nullable=False)
|
||||
token_display: Mapped[str] = mapped_column(String(20), nullable=False)
|
||||
bound_ip: Mapped[str] = mapped_column(CIDR, nullable=False)
|
||||
bound_ip: Mapped[str] = mapped_column(Text, nullable=False)
|
||||
attempt_ip: Mapped[str] = mapped_column(INET, nullable=False)
|
||||
alerted: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False, server_default=text("FALSE"))
|
||||
intercepted_at: Mapped[datetime] = mapped_column(
|
||||
|
||||
@@ -2,27 +2,42 @@ from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import DateTime, Index, SmallInteger, String, func, text
|
||||
from sqlalchemy.dialects.postgresql import CIDR
|
||||
from sqlalchemy import DateTime, Index, SmallInteger, String, Text, func, text
|
||||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from app.models.db import Base
|
||||
|
||||
STATUS_ACTIVE = 1
|
||||
STATUS_BANNED = 2
|
||||
BINDING_MODE_SINGLE = "single"
|
||||
BINDING_MODE_MULTIPLE = "multiple"
|
||||
BINDING_MODE_ALL = "all"
|
||||
|
||||
|
||||
class TokenBinding(Base):
|
||||
__tablename__ = "token_bindings"
|
||||
__table_args__ = (
|
||||
Index("idx_token_bindings_hash", "token_hash"),
|
||||
Index("idx_token_bindings_ip", "bound_ip", postgresql_using="gist", postgresql_ops={"bound_ip": "inet_ops"}),
|
||||
Index("idx_token_bindings_ip", "bound_ip"),
|
||||
)
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
|
||||
token_hash: Mapped[str] = mapped_column(String(64), unique=True, nullable=False)
|
||||
token_display: Mapped[str] = mapped_column(String(20), nullable=False)
|
||||
bound_ip: Mapped[str] = mapped_column(CIDR, nullable=False)
|
||||
bound_ip: Mapped[str] = mapped_column(Text, nullable=False)
|
||||
binding_mode: Mapped[str] = mapped_column(
|
||||
String(16),
|
||||
nullable=False,
|
||||
default=BINDING_MODE_SINGLE,
|
||||
server_default=text("'single'"),
|
||||
)
|
||||
allowed_ips: Mapped[list[str]] = mapped_column(
|
||||
JSONB,
|
||||
nullable=False,
|
||||
default=list,
|
||||
server_default=text("'[]'::jsonb"),
|
||||
)
|
||||
status: Mapped[int] = mapped_column(
|
||||
SmallInteger,
|
||||
nullable=False,
|
||||
|
||||
@@ -1,8 +1,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from ipaddress import ip_address, ip_network
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
||||
|
||||
from app.models.token_binding import BINDING_MODE_ALL, BINDING_MODE_MULTIPLE, BINDING_MODE_SINGLE
|
||||
|
||||
|
||||
class BindingItem(BaseModel):
|
||||
@@ -11,6 +14,8 @@ class BindingItem(BaseModel):
|
||||
id: int
|
||||
token_display: str
|
||||
bound_ip: str
|
||||
binding_mode: str
|
||||
allowed_ips: list[str]
|
||||
status: int
|
||||
status_label: str
|
||||
first_used_at: datetime
|
||||
@@ -31,12 +36,32 @@ class BindingActionRequest(BaseModel):
|
||||
|
||||
class BindingIPUpdateRequest(BaseModel):
|
||||
id: int = Field(gt=0)
|
||||
bound_ip: str = Field(min_length=3, max_length=64)
|
||||
binding_mode: str = Field(default=BINDING_MODE_SINGLE)
|
||||
allowed_ips: list[str] = Field(default_factory=list)
|
||||
|
||||
@field_validator("bound_ip")
|
||||
@classmethod
|
||||
def validate_bound_ip(cls, value: str) -> str:
|
||||
from ipaddress import ip_network
|
||||
@model_validator(mode="after")
|
||||
def validate_binding_rule(self):
|
||||
allowed_ips = [item.strip() for item in self.allowed_ips if item.strip()]
|
||||
|
||||
ip_network(value, strict=False)
|
||||
return value
|
||||
if self.binding_mode == BINDING_MODE_ALL:
|
||||
self.allowed_ips = []
|
||||
return self
|
||||
|
||||
if self.binding_mode == BINDING_MODE_SINGLE:
|
||||
if len(allowed_ips) != 1:
|
||||
raise ValueError("Single binding mode requires exactly one IP or CIDR.")
|
||||
ip_network(allowed_ips[0], strict=False)
|
||||
self.allowed_ips = allowed_ips
|
||||
return self
|
||||
|
||||
if self.binding_mode == BINDING_MODE_MULTIPLE:
|
||||
if not allowed_ips:
|
||||
raise ValueError("Multiple binding mode requires at least one IP.")
|
||||
normalized: list[str] = []
|
||||
for item in allowed_ips:
|
||||
ip_address(item)
|
||||
normalized.append(item)
|
||||
self.allowed_ips = normalized
|
||||
return self
|
||||
|
||||
raise ValueError("Unsupported binding mode.")
|
||||
|
||||
@@ -5,18 +5,25 @@ import json
|
||||
import logging
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, date, timedelta
|
||||
from datetime import date, timedelta
|
||||
from typing import Callable
|
||||
|
||||
from redis.asyncio import Redis
|
||||
from sqlalchemy import func, select, text, update
|
||||
from sqlalchemy import func, select, update
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
|
||||
from app.config import RuntimeSettings, Settings
|
||||
from app.core.ip_utils import is_ip_in_network
|
||||
from app.core.security import hash_token, mask_token
|
||||
from app.models.token_binding import STATUS_ACTIVE, STATUS_BANNED, TokenBinding
|
||||
from app.models.token_binding import (
|
||||
BINDING_MODE_ALL,
|
||||
BINDING_MODE_MULTIPLE,
|
||||
BINDING_MODE_SINGLE,
|
||||
STATUS_ACTIVE,
|
||||
STATUS_BANNED,
|
||||
TokenBinding,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -27,6 +34,8 @@ class BindingRecord:
|
||||
token_hash: str
|
||||
token_display: str
|
||||
bound_ip: str
|
||||
binding_mode: str
|
||||
allowed_ips: list[str]
|
||||
status: int
|
||||
ip_matched: bool
|
||||
|
||||
@@ -104,42 +113,101 @@ class BindingService:
|
||||
def metrics_key(self, target_date: date) -> str:
|
||||
return f"sentinel:metrics:{target_date.isoformat()}"
|
||||
|
||||
def build_bound_ip_display(self, binding_mode: str, allowed_ips: list[str]) -> str:
|
||||
if binding_mode == BINDING_MODE_ALL:
|
||||
return "ALL"
|
||||
if not allowed_ips:
|
||||
return "-"
|
||||
if binding_mode == BINDING_MODE_MULTIPLE:
|
||||
return ", ".join(allowed_ips)
|
||||
return allowed_ips[0]
|
||||
|
||||
def is_client_allowed(self, client_ip: str, binding_mode: str, allowed_ips: list[str]) -> bool:
|
||||
if binding_mode == BINDING_MODE_ALL:
|
||||
return True
|
||||
return any(is_ip_in_network(client_ip, item) for item in allowed_ips)
|
||||
|
||||
def to_binding_record(self, binding: TokenBinding, client_ip: str) -> BindingRecord:
|
||||
allowed_ips = [str(item) for item in binding.allowed_ips]
|
||||
binding_mode = binding.binding_mode or BINDING_MODE_SINGLE
|
||||
return BindingRecord(
|
||||
id=binding.id,
|
||||
token_hash=binding.token_hash,
|
||||
token_display=binding.token_display,
|
||||
bound_ip=binding.bound_ip,
|
||||
binding_mode=binding_mode,
|
||||
allowed_ips=allowed_ips,
|
||||
status=binding.status,
|
||||
ip_matched=self.is_client_allowed(client_ip, binding_mode, allowed_ips),
|
||||
)
|
||||
|
||||
def denied_result(
|
||||
self,
|
||||
token_hash: str,
|
||||
token_display: str,
|
||||
bound_ip: str,
|
||||
detail: str,
|
||||
*,
|
||||
should_alert: bool = True,
|
||||
status_code: int = 403,
|
||||
) -> BindingCheckResult:
|
||||
return BindingCheckResult(
|
||||
allowed=False,
|
||||
status_code=status_code,
|
||||
detail=detail,
|
||||
token_hash=token_hash,
|
||||
token_display=token_display,
|
||||
bound_ip=bound_ip,
|
||||
should_alert=should_alert,
|
||||
)
|
||||
|
||||
def allowed_result(
|
||||
self,
|
||||
token_hash: str,
|
||||
token_display: str,
|
||||
bound_ip: str,
|
||||
detail: str,
|
||||
*,
|
||||
newly_bound: bool = False,
|
||||
) -> BindingCheckResult:
|
||||
return BindingCheckResult(
|
||||
allowed=True,
|
||||
status_code=200,
|
||||
detail=detail,
|
||||
token_hash=token_hash,
|
||||
token_display=token_display,
|
||||
bound_ip=bound_ip,
|
||||
newly_bound=newly_bound,
|
||||
)
|
||||
|
||||
def evaluate_existing_record(
|
||||
self,
|
||||
record: BindingRecord,
|
||||
token_hash: str,
|
||||
token_display: str,
|
||||
detail: str,
|
||||
) -> BindingCheckResult:
|
||||
if record.status == STATUS_BANNED:
|
||||
return self.denied_result(token_hash, token_display, record.bound_ip, "Token is banned.")
|
||||
if record.ip_matched:
|
||||
self.record_last_used(token_hash)
|
||||
return self.allowed_result(token_hash, token_display, record.bound_ip, detail)
|
||||
return self.denied_result(
|
||||
token_hash,
|
||||
token_display,
|
||||
record.bound_ip,
|
||||
"Client IP does not match the allowed binding rule.",
|
||||
)
|
||||
|
||||
async def evaluate_token_binding(self, token: str, client_ip: str) -> BindingCheckResult:
|
||||
token_hash = hash_token(token, self.settings.sentinel_hmac_secret)
|
||||
token_display = mask_token(token)
|
||||
|
||||
cache_hit, cache_available = await self._load_binding_from_cache(token_hash)
|
||||
cache_hit, cache_available = await self._load_binding_from_cache(token_hash, client_ip)
|
||||
if cache_hit is not None:
|
||||
if cache_hit.status == STATUS_BANNED:
|
||||
return BindingCheckResult(
|
||||
allowed=False,
|
||||
status_code=403,
|
||||
detail="Token is banned.",
|
||||
token_hash=token_hash,
|
||||
token_display=token_display,
|
||||
bound_ip=cache_hit.bound_ip,
|
||||
should_alert=True,
|
||||
)
|
||||
if is_ip_in_network(client_ip, cache_hit.bound_ip):
|
||||
if cache_hit.ip_matched:
|
||||
await self._touch_cache(token_hash)
|
||||
self.record_last_used(token_hash)
|
||||
return BindingCheckResult(
|
||||
allowed=True,
|
||||
status_code=200,
|
||||
detail="Allowed from cache.",
|
||||
token_hash=token_hash,
|
||||
token_display=token_display,
|
||||
bound_ip=cache_hit.bound_ip,
|
||||
)
|
||||
return BindingCheckResult(
|
||||
allowed=False,
|
||||
status_code=403,
|
||||
detail="Client IP does not match the bound CIDR.",
|
||||
token_hash=token_hash,
|
||||
token_display=token_display,
|
||||
bound_ip=cache_hit.bound_ip,
|
||||
should_alert=True,
|
||||
)
|
||||
return self.evaluate_existing_record(cache_hit, token_hash, token_display, "Allowed from cache.")
|
||||
|
||||
if not cache_available:
|
||||
logger.warning("Redis is unavailable. Falling back to PostgreSQL for token binding.")
|
||||
@@ -159,36 +227,8 @@ class BindingService:
|
||||
return self._handle_backend_failure(token_hash, token_display)
|
||||
|
||||
if record is not None:
|
||||
await self.sync_binding_cache(record.token_hash, record.bound_ip, record.status)
|
||||
if record.status == STATUS_BANNED:
|
||||
return BindingCheckResult(
|
||||
allowed=False,
|
||||
status_code=403,
|
||||
detail="Token is banned.",
|
||||
token_hash=token_hash,
|
||||
token_display=token_display,
|
||||
bound_ip=record.bound_ip,
|
||||
should_alert=True,
|
||||
)
|
||||
if record.ip_matched:
|
||||
self.record_last_used(token_hash)
|
||||
return BindingCheckResult(
|
||||
allowed=True,
|
||||
status_code=200,
|
||||
detail="Allowed from PostgreSQL.",
|
||||
token_hash=token_hash,
|
||||
token_display=token_display,
|
||||
bound_ip=record.bound_ip,
|
||||
)
|
||||
return BindingCheckResult(
|
||||
allowed=False,
|
||||
status_code=403,
|
||||
detail="Client IP does not match the bound CIDR.",
|
||||
token_hash=token_hash,
|
||||
token_display=token_display,
|
||||
bound_ip=record.bound_ip,
|
||||
should_alert=True,
|
||||
)
|
||||
await self.sync_binding_cache(record.token_hash, record.bound_ip, record.binding_mode, record.allowed_ips, record.status)
|
||||
return self.evaluate_existing_record(record, token_hash, token_display, "Allowed from PostgreSQL.")
|
||||
|
||||
try:
|
||||
created = await self._create_binding(token_hash, token_display, client_ip)
|
||||
@@ -202,52 +242,42 @@ class BindingService:
|
||||
return self._handle_backend_failure(token_hash, token_display)
|
||||
if existing is None:
|
||||
return self._handle_backend_failure(token_hash, token_display)
|
||||
await self.sync_binding_cache(existing.token_hash, existing.bound_ip, existing.status)
|
||||
if existing.status == STATUS_BANNED:
|
||||
return BindingCheckResult(
|
||||
allowed=False,
|
||||
status_code=403,
|
||||
detail="Token is banned.",
|
||||
token_hash=token_hash,
|
||||
token_display=token_display,
|
||||
bound_ip=existing.bound_ip,
|
||||
should_alert=True,
|
||||
)
|
||||
if existing.ip_matched:
|
||||
self.record_last_used(token_hash)
|
||||
return BindingCheckResult(
|
||||
allowed=True,
|
||||
status_code=200,
|
||||
detail="Allowed after concurrent bind resolution.",
|
||||
token_hash=token_hash,
|
||||
token_display=token_display,
|
||||
bound_ip=existing.bound_ip,
|
||||
)
|
||||
return BindingCheckResult(
|
||||
allowed=False,
|
||||
status_code=403,
|
||||
detail="Client IP does not match the bound CIDR.",
|
||||
token_hash=token_hash,
|
||||
token_display=token_display,
|
||||
bound_ip=existing.bound_ip,
|
||||
should_alert=True,
|
||||
await self.sync_binding_cache(
|
||||
existing.token_hash,
|
||||
existing.bound_ip,
|
||||
existing.binding_mode,
|
||||
existing.allowed_ips,
|
||||
existing.status,
|
||||
)
|
||||
return self.evaluate_existing_record(existing, token_hash, token_display, "Allowed after concurrent bind resolution.")
|
||||
|
||||
await self.sync_binding_cache(created.token_hash, created.bound_ip, created.status)
|
||||
return BindingCheckResult(
|
||||
allowed=True,
|
||||
status_code=200,
|
||||
detail="First-use bind created.",
|
||||
token_hash=token_hash,
|
||||
token_display=token_display,
|
||||
bound_ip=created.bound_ip,
|
||||
newly_bound=True,
|
||||
await self.sync_binding_cache(
|
||||
created.token_hash,
|
||||
created.bound_ip,
|
||||
created.binding_mode,
|
||||
created.allowed_ips,
|
||||
created.status,
|
||||
)
|
||||
return self.allowed_result(token_hash, token_display, created.bound_ip, "First-use bind created.", newly_bound=True)
|
||||
|
||||
async def sync_binding_cache(self, token_hash: str, bound_ip: str, status_code: int) -> None:
|
||||
async def sync_binding_cache(
|
||||
self,
|
||||
token_hash: str,
|
||||
bound_ip: str,
|
||||
binding_mode: str,
|
||||
allowed_ips: list[str],
|
||||
status_code: int,
|
||||
) -> None:
|
||||
if self.redis is None:
|
||||
return
|
||||
payload = json.dumps({"bound_ip": bound_ip, "status": status_code})
|
||||
payload = json.dumps(
|
||||
{
|
||||
"bound_ip": bound_ip,
|
||||
"binding_mode": binding_mode,
|
||||
"allowed_ips": allowed_ips,
|
||||
"status": status_code,
|
||||
}
|
||||
)
|
||||
try:
|
||||
await self.redis.set(self.cache_key(token_hash), payload, ex=self.settings.redis_binding_ttl_seconds)
|
||||
except Exception:
|
||||
@@ -336,7 +366,7 @@ class BindingService:
|
||||
)
|
||||
return series
|
||||
|
||||
async def _load_binding_from_cache(self, token_hash: str) -> tuple[BindingRecord | None, bool]:
|
||||
async def _load_binding_from_cache(self, token_hash: str, client_ip: str) -> tuple[BindingRecord | None, bool]:
|
||||
if self.redis is None:
|
||||
return None, False
|
||||
try:
|
||||
@@ -348,14 +378,18 @@ class BindingService:
|
||||
return None, True
|
||||
|
||||
data = json.loads(raw)
|
||||
allowed_ips = [str(item) for item in data.get("allowed_ips", [])]
|
||||
binding_mode = str(data.get("binding_mode", BINDING_MODE_SINGLE))
|
||||
return (
|
||||
BindingRecord(
|
||||
id=0,
|
||||
token_hash=token_hash,
|
||||
token_display="",
|
||||
bound_ip=data["bound_ip"],
|
||||
bound_ip=str(data.get("bound_ip", self.build_bound_ip_display(binding_mode, allowed_ips))),
|
||||
binding_mode=binding_mode,
|
||||
allowed_ips=allowed_ips,
|
||||
status=int(data["status"]),
|
||||
ip_matched=False,
|
||||
ip_matched=self.is_client_allowed(client_ip, binding_mode, allowed_ips),
|
||||
),
|
||||
True,
|
||||
)
|
||||
@@ -369,69 +403,33 @@ class BindingService:
|
||||
logger.warning("Failed to extend binding cache TTL.", extra={"token_hash": token_hash})
|
||||
|
||||
async def _load_binding_from_db(self, token_hash: str, client_ip: str) -> BindingRecord | None:
|
||||
query = text(
|
||||
"""
|
||||
SELECT
|
||||
id,
|
||||
token_hash,
|
||||
token_display,
|
||||
bound_ip::text AS bound_ip,
|
||||
status,
|
||||
CAST(:client_ip AS inet) << bound_ip AS ip_matched
|
||||
FROM token_bindings
|
||||
WHERE token_hash = :token_hash
|
||||
LIMIT 1
|
||||
"""
|
||||
)
|
||||
async with self.session_factory() as session:
|
||||
result = await session.execute(query, {"token_hash": token_hash, "client_ip": client_ip})
|
||||
row = result.mappings().first()
|
||||
if row is None:
|
||||
binding = await session.scalar(select(TokenBinding).where(TokenBinding.token_hash == token_hash).limit(1))
|
||||
if binding is None:
|
||||
return None
|
||||
return BindingRecord(
|
||||
id=int(row["id"]),
|
||||
token_hash=str(row["token_hash"]),
|
||||
token_display=str(row["token_display"]),
|
||||
bound_ip=str(row["bound_ip"]),
|
||||
status=int(row["status"]),
|
||||
ip_matched=bool(row["ip_matched"]),
|
||||
)
|
||||
return self.to_binding_record(binding, client_ip)
|
||||
|
||||
async def _create_binding(self, token_hash: str, token_display: str, client_ip: str) -> BindingRecord | None:
|
||||
statement = text(
|
||||
"""
|
||||
INSERT INTO token_bindings (token_hash, token_display, bound_ip, status)
|
||||
VALUES (:token_hash, :token_display, CAST(:bound_ip AS cidr), :status)
|
||||
ON CONFLICT (token_hash) DO NOTHING
|
||||
RETURNING id, token_hash, token_display, bound_ip::text AS bound_ip, status
|
||||
"""
|
||||
)
|
||||
async with self.session_factory() as session:
|
||||
try:
|
||||
result = await session.execute(
|
||||
statement,
|
||||
{
|
||||
"token_hash": token_hash,
|
||||
"token_display": token_display,
|
||||
"bound_ip": client_ip,
|
||||
"status": STATUS_ACTIVE,
|
||||
},
|
||||
binding = TokenBinding(
|
||||
token_hash=token_hash,
|
||||
token_display=token_display,
|
||||
bound_ip=client_ip,
|
||||
binding_mode=BINDING_MODE_SINGLE,
|
||||
allowed_ips=[client_ip],
|
||||
status=STATUS_ACTIVE,
|
||||
)
|
||||
row = result.mappings().first()
|
||||
session.add(binding)
|
||||
await session.flush()
|
||||
await session.commit()
|
||||
except SQLAlchemyError:
|
||||
await session.refresh(binding)
|
||||
except SQLAlchemyError as exc:
|
||||
await session.rollback()
|
||||
if "duplicate key" in str(exc).lower() or "unique" in str(exc).lower():
|
||||
return None
|
||||
raise
|
||||
if row is None:
|
||||
return None
|
||||
return BindingRecord(
|
||||
id=int(row["id"]),
|
||||
token_hash=str(row["token_hash"]),
|
||||
token_display=str(row["token_display"]),
|
||||
bound_ip=str(row["bound_ip"]),
|
||||
status=int(row["status"]),
|
||||
ip_matched=True,
|
||||
)
|
||||
return self.to_binding_record(binding, client_ip)
|
||||
|
||||
def _handle_backend_failure(self, token_hash: str, token_display: str) -> BindingCheckResult:
|
||||
runtime_settings = self.runtime_settings_getter()
|
||||
|
||||
Reference in New Issue
Block a user