Add multi-IP binding modes and deployment guide

This commit is contained in:
2026-03-04 15:30:13 +08:00
parent 4348ee799b
commit eed1acd454
12 changed files with 509 additions and 217 deletions

View File

@@ -28,6 +28,8 @@ def to_binding_item(binding: TokenBinding, binding_service: BindingService) -> B
id=binding.id,
token_display=binding.token_display,
bound_ip=str(binding.bound_ip),
binding_mode=binding.binding_mode,
allowed_ips=[str(item) for item in binding.allowed_ips],
status=binding.status,
status_label=binding_service.status_label(binding.status),
first_used_at=binding.first_used_at,
@@ -70,7 +72,13 @@ def log_admin_action(request: Request, settings: Settings, action: str, binding_
async def commit_binding_cache(binding: TokenBinding, binding_service: BindingService) -> None:
await binding_service.sync_binding_cache(binding.token_hash, str(binding.bound_ip), binding.status)
await binding_service.sync_binding_cache(
binding.token_hash,
str(binding.bound_ip),
binding.binding_mode,
[str(item) for item in binding.allowed_ips],
binding.status,
)
async def update_binding_status(
@@ -138,7 +146,9 @@ async def update_bound_ip(
binding_service: BindingService = Depends(get_binding_service),
):
binding = await get_binding_or_404(session, payload.id)
binding.bound_ip = payload.bound_ip
binding.binding_mode = payload.binding_mode
binding.allowed_ips = payload.allowed_ips
binding.bound_ip = binding_service.build_bound_ip_display(payload.binding_mode, payload.allowed_ips)
await session.commit()
await commit_binding_cache(binding, binding_service)
log_admin_action(request, settings, "update_ip", payload.id)

View File

@@ -76,7 +76,7 @@ async def build_recent_intercepts(session: AsyncSession) -> list[InterceptLogIte
InterceptLogItem(
id=item.id,
token_display=item.token_display,
bound_ip=str(item.bound_ip),
bound_ip=item.bound_ip,
attempt_ip=str(item.attempt_ip),
alerted=item.alerted,
intercepted_at=item.intercepted_at,

View File

@@ -38,7 +38,7 @@ def to_log_item(item: InterceptLog) -> InterceptLogItem:
return InterceptLogItem(
id=item.id,
token_display=item.token_display,
bound_ip=str(item.bound_ip),
bound_ip=item.bound_ip,
attempt_ip=str(item.attempt_ip),
alerted=item.alerted,
intercepted_at=item.intercepted_at,
@@ -47,13 +47,13 @@ def to_log_item(item: InterceptLog) -> InterceptLogItem:
def write_log_csv(buffer: io.StringIO, logs: list[InterceptLog]) -> None:
writer = csv.writer(buffer)
writer.writerow(["id", "token_display", "bound_ip", "attempt_ip", "alerted", "intercepted_at"])
writer.writerow(["id", "token_display", "binding_rule", "attempt_ip", "alerted", "intercepted_at"])
for item in logs:
writer.writerow(
[
item.id,
item.token_display,
str(item.bound_ip),
item.bound_ip,
str(item.attempt_ip),
item.alerted,
item.intercepted_at.isoformat(),

View File

@@ -14,7 +14,7 @@ from redis.asyncio import from_url as redis_from_url
from app.api import auth, bindings, dashboard, logs, settings as settings_api
from app.config import RUNTIME_SETTINGS_REDIS_KEY, RuntimeSettings, Settings, get_settings
from app.models import intercept_log, token_binding # noqa: F401
from app.models.db import close_db, get_session_factory, init_db
from app.models.db import close_db, ensure_schema_compatibility, get_session_factory, init_db
from app.proxy.handler import router as proxy_router
from app.services.alert_service import AlertService
from app.services.archive_service import ArchiveService
@@ -100,6 +100,7 @@ async def load_runtime_settings(redis: Redis | None, settings: Settings) -> Runt
async def lifespan(app: FastAPI):
settings = get_settings()
init_db(settings)
await ensure_schema_compatibility()
session_factory = get_session_factory()
redis: Redis | None = redis_from_url(

View File

@@ -1,5 +1,6 @@
from __future__ import annotations
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.orm import DeclarativeBase
@@ -40,6 +41,31 @@ def get_session_factory() -> async_sessionmaker[AsyncSession]:
return _session_factory
async def ensure_schema_compatibility() -> None:
engine = get_engine()
statements = [
"DROP INDEX IF EXISTS idx_token_bindings_ip",
"ALTER TABLE token_bindings ALTER COLUMN bound_ip TYPE TEXT USING bound_ip::text",
"ALTER TABLE intercept_logs ALTER COLUMN bound_ip TYPE TEXT USING bound_ip::text",
"ALTER TABLE token_bindings ADD COLUMN IF NOT EXISTS binding_mode VARCHAR(16) DEFAULT 'single'",
"ALTER TABLE token_bindings ADD COLUMN IF NOT EXISTS allowed_ips JSONB DEFAULT '[]'::jsonb",
"UPDATE token_bindings SET binding_mode = 'single' WHERE binding_mode IS NULL OR binding_mode = ''",
"""
UPDATE token_bindings
SET allowed_ips = jsonb_build_array(bound_ip)
WHERE allowed_ips IS NULL OR allowed_ips = '[]'::jsonb
""",
"ALTER TABLE token_bindings ALTER COLUMN binding_mode SET NOT NULL",
"ALTER TABLE token_bindings ALTER COLUMN allowed_ips SET NOT NULL",
"ALTER TABLE token_bindings ALTER COLUMN binding_mode SET DEFAULT 'single'",
"ALTER TABLE token_bindings ALTER COLUMN allowed_ips SET DEFAULT '[]'::jsonb",
"CREATE INDEX IF NOT EXISTS idx_token_bindings_ip ON token_bindings(bound_ip)",
]
async with engine.begin() as connection:
for statement in statements:
await connection.execute(text(statement))
async def close_db() -> None:
global _engine, _session_factory
if _engine is not None:

View File

@@ -2,8 +2,8 @@ from __future__ import annotations
from datetime import datetime
from sqlalchemy import Boolean, DateTime, Index, String, func, text
from sqlalchemy.dialects.postgresql import CIDR, INET
from sqlalchemy import Boolean, DateTime, Index, String, Text, func, text
from sqlalchemy.dialects.postgresql import INET
from sqlalchemy.orm import Mapped, mapped_column
from app.models.db import Base
@@ -19,7 +19,7 @@ class InterceptLog(Base):
id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
token_hash: Mapped[str] = mapped_column(String(64), nullable=False)
token_display: Mapped[str] = mapped_column(String(20), nullable=False)
bound_ip: Mapped[str] = mapped_column(CIDR, nullable=False)
bound_ip: Mapped[str] = mapped_column(Text, nullable=False)
attempt_ip: Mapped[str] = mapped_column(INET, nullable=False)
alerted: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False, server_default=text("FALSE"))
intercepted_at: Mapped[datetime] = mapped_column(

View File

@@ -2,27 +2,42 @@ from __future__ import annotations
from datetime import datetime
from sqlalchemy import DateTime, Index, SmallInteger, String, func, text
from sqlalchemy.dialects.postgresql import CIDR
from sqlalchemy import DateTime, Index, SmallInteger, String, Text, func, text
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.orm import Mapped, mapped_column
from app.models.db import Base
STATUS_ACTIVE = 1
STATUS_BANNED = 2
BINDING_MODE_SINGLE = "single"
BINDING_MODE_MULTIPLE = "multiple"
BINDING_MODE_ALL = "all"
class TokenBinding(Base):
__tablename__ = "token_bindings"
__table_args__ = (
Index("idx_token_bindings_hash", "token_hash"),
Index("idx_token_bindings_ip", "bound_ip", postgresql_using="gist", postgresql_ops={"bound_ip": "inet_ops"}),
Index("idx_token_bindings_ip", "bound_ip"),
)
id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
token_hash: Mapped[str] = mapped_column(String(64), unique=True, nullable=False)
token_display: Mapped[str] = mapped_column(String(20), nullable=False)
bound_ip: Mapped[str] = mapped_column(CIDR, nullable=False)
bound_ip: Mapped[str] = mapped_column(Text, nullable=False)
binding_mode: Mapped[str] = mapped_column(
String(16),
nullable=False,
default=BINDING_MODE_SINGLE,
server_default=text("'single'"),
)
allowed_ips: Mapped[list[str]] = mapped_column(
JSONB,
nullable=False,
default=list,
server_default=text("'[]'::jsonb"),
)
status: Mapped[int] = mapped_column(
SmallInteger,
nullable=False,

View File

@@ -1,8 +1,11 @@
from __future__ import annotations
from datetime import datetime
from ipaddress import ip_address, ip_network
from pydantic import BaseModel, ConfigDict, Field, field_validator
from pydantic import BaseModel, ConfigDict, Field, model_validator
from app.models.token_binding import BINDING_MODE_ALL, BINDING_MODE_MULTIPLE, BINDING_MODE_SINGLE
class BindingItem(BaseModel):
@@ -11,6 +14,8 @@ class BindingItem(BaseModel):
id: int
token_display: str
bound_ip: str
binding_mode: str
allowed_ips: list[str]
status: int
status_label: str
first_used_at: datetime
@@ -31,12 +36,32 @@ class BindingActionRequest(BaseModel):
class BindingIPUpdateRequest(BaseModel):
id: int = Field(gt=0)
bound_ip: str = Field(min_length=3, max_length=64)
binding_mode: str = Field(default=BINDING_MODE_SINGLE)
allowed_ips: list[str] = Field(default_factory=list)
@field_validator("bound_ip")
@classmethod
def validate_bound_ip(cls, value: str) -> str:
from ipaddress import ip_network
@model_validator(mode="after")
def validate_binding_rule(self):
allowed_ips = [item.strip() for item in self.allowed_ips if item.strip()]
ip_network(value, strict=False)
return value
if self.binding_mode == BINDING_MODE_ALL:
self.allowed_ips = []
return self
if self.binding_mode == BINDING_MODE_SINGLE:
if len(allowed_ips) != 1:
raise ValueError("Single binding mode requires exactly one IP or CIDR.")
ip_network(allowed_ips[0], strict=False)
self.allowed_ips = allowed_ips
return self
if self.binding_mode == BINDING_MODE_MULTIPLE:
if not allowed_ips:
raise ValueError("Multiple binding mode requires at least one IP.")
normalized: list[str] = []
for item in allowed_ips:
ip_address(item)
normalized.append(item)
self.allowed_ips = normalized
return self
raise ValueError("Unsupported binding mode.")

View File

@@ -5,18 +5,25 @@ import json
import logging
import time
from dataclasses import dataclass
from datetime import UTC, date, timedelta
from datetime import date, timedelta
from typing import Callable
from redis.asyncio import Redis
from sqlalchemy import func, select, text, update
from sqlalchemy import func, select, update
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from app.config import RuntimeSettings, Settings
from app.core.ip_utils import is_ip_in_network
from app.core.security import hash_token, mask_token
from app.models.token_binding import STATUS_ACTIVE, STATUS_BANNED, TokenBinding
from app.models.token_binding import (
BINDING_MODE_ALL,
BINDING_MODE_MULTIPLE,
BINDING_MODE_SINGLE,
STATUS_ACTIVE,
STATUS_BANNED,
TokenBinding,
)
logger = logging.getLogger(__name__)
@@ -27,6 +34,8 @@ class BindingRecord:
token_hash: str
token_display: str
bound_ip: str
binding_mode: str
allowed_ips: list[str]
status: int
ip_matched: bool
@@ -104,42 +113,101 @@ class BindingService:
def metrics_key(self, target_date: date) -> str:
return f"sentinel:metrics:{target_date.isoformat()}"
def build_bound_ip_display(self, binding_mode: str, allowed_ips: list[str]) -> str:
if binding_mode == BINDING_MODE_ALL:
return "ALL"
if not allowed_ips:
return "-"
if binding_mode == BINDING_MODE_MULTIPLE:
return ", ".join(allowed_ips)
return allowed_ips[0]
def is_client_allowed(self, client_ip: str, binding_mode: str, allowed_ips: list[str]) -> bool:
if binding_mode == BINDING_MODE_ALL:
return True
return any(is_ip_in_network(client_ip, item) for item in allowed_ips)
def to_binding_record(self, binding: TokenBinding, client_ip: str) -> BindingRecord:
allowed_ips = [str(item) for item in binding.allowed_ips]
binding_mode = binding.binding_mode or BINDING_MODE_SINGLE
return BindingRecord(
id=binding.id,
token_hash=binding.token_hash,
token_display=binding.token_display,
bound_ip=binding.bound_ip,
binding_mode=binding_mode,
allowed_ips=allowed_ips,
status=binding.status,
ip_matched=self.is_client_allowed(client_ip, binding_mode, allowed_ips),
)
def denied_result(
self,
token_hash: str,
token_display: str,
bound_ip: str,
detail: str,
*,
should_alert: bool = True,
status_code: int = 403,
) -> BindingCheckResult:
return BindingCheckResult(
allowed=False,
status_code=status_code,
detail=detail,
token_hash=token_hash,
token_display=token_display,
bound_ip=bound_ip,
should_alert=should_alert,
)
def allowed_result(
self,
token_hash: str,
token_display: str,
bound_ip: str,
detail: str,
*,
newly_bound: bool = False,
) -> BindingCheckResult:
return BindingCheckResult(
allowed=True,
status_code=200,
detail=detail,
token_hash=token_hash,
token_display=token_display,
bound_ip=bound_ip,
newly_bound=newly_bound,
)
def evaluate_existing_record(
self,
record: BindingRecord,
token_hash: str,
token_display: str,
detail: str,
) -> BindingCheckResult:
if record.status == STATUS_BANNED:
return self.denied_result(token_hash, token_display, record.bound_ip, "Token is banned.")
if record.ip_matched:
self.record_last_used(token_hash)
return self.allowed_result(token_hash, token_display, record.bound_ip, detail)
return self.denied_result(
token_hash,
token_display,
record.bound_ip,
"Client IP does not match the allowed binding rule.",
)
async def evaluate_token_binding(self, token: str, client_ip: str) -> BindingCheckResult:
token_hash = hash_token(token, self.settings.sentinel_hmac_secret)
token_display = mask_token(token)
cache_hit, cache_available = await self._load_binding_from_cache(token_hash)
cache_hit, cache_available = await self._load_binding_from_cache(token_hash, client_ip)
if cache_hit is not None:
if cache_hit.status == STATUS_BANNED:
return BindingCheckResult(
allowed=False,
status_code=403,
detail="Token is banned.",
token_hash=token_hash,
token_display=token_display,
bound_ip=cache_hit.bound_ip,
should_alert=True,
)
if is_ip_in_network(client_ip, cache_hit.bound_ip):
if cache_hit.ip_matched:
await self._touch_cache(token_hash)
self.record_last_used(token_hash)
return BindingCheckResult(
allowed=True,
status_code=200,
detail="Allowed from cache.",
token_hash=token_hash,
token_display=token_display,
bound_ip=cache_hit.bound_ip,
)
return BindingCheckResult(
allowed=False,
status_code=403,
detail="Client IP does not match the bound CIDR.",
token_hash=token_hash,
token_display=token_display,
bound_ip=cache_hit.bound_ip,
should_alert=True,
)
return self.evaluate_existing_record(cache_hit, token_hash, token_display, "Allowed from cache.")
if not cache_available:
logger.warning("Redis is unavailable. Falling back to PostgreSQL for token binding.")
@@ -159,36 +227,8 @@ class BindingService:
return self._handle_backend_failure(token_hash, token_display)
if record is not None:
await self.sync_binding_cache(record.token_hash, record.bound_ip, record.status)
if record.status == STATUS_BANNED:
return BindingCheckResult(
allowed=False,
status_code=403,
detail="Token is banned.",
token_hash=token_hash,
token_display=token_display,
bound_ip=record.bound_ip,
should_alert=True,
)
if record.ip_matched:
self.record_last_used(token_hash)
return BindingCheckResult(
allowed=True,
status_code=200,
detail="Allowed from PostgreSQL.",
token_hash=token_hash,
token_display=token_display,
bound_ip=record.bound_ip,
)
return BindingCheckResult(
allowed=False,
status_code=403,
detail="Client IP does not match the bound CIDR.",
token_hash=token_hash,
token_display=token_display,
bound_ip=record.bound_ip,
should_alert=True,
)
await self.sync_binding_cache(record.token_hash, record.bound_ip, record.binding_mode, record.allowed_ips, record.status)
return self.evaluate_existing_record(record, token_hash, token_display, "Allowed from PostgreSQL.")
try:
created = await self._create_binding(token_hash, token_display, client_ip)
@@ -202,52 +242,42 @@ class BindingService:
return self._handle_backend_failure(token_hash, token_display)
if existing is None:
return self._handle_backend_failure(token_hash, token_display)
await self.sync_binding_cache(existing.token_hash, existing.bound_ip, existing.status)
if existing.status == STATUS_BANNED:
return BindingCheckResult(
allowed=False,
status_code=403,
detail="Token is banned.",
token_hash=token_hash,
token_display=token_display,
bound_ip=existing.bound_ip,
should_alert=True,
)
if existing.ip_matched:
self.record_last_used(token_hash)
return BindingCheckResult(
allowed=True,
status_code=200,
detail="Allowed after concurrent bind resolution.",
token_hash=token_hash,
token_display=token_display,
bound_ip=existing.bound_ip,
)
return BindingCheckResult(
allowed=False,
status_code=403,
detail="Client IP does not match the bound CIDR.",
token_hash=token_hash,
token_display=token_display,
bound_ip=existing.bound_ip,
should_alert=True,
await self.sync_binding_cache(
existing.token_hash,
existing.bound_ip,
existing.binding_mode,
existing.allowed_ips,
existing.status,
)
return self.evaluate_existing_record(existing, token_hash, token_display, "Allowed after concurrent bind resolution.")
await self.sync_binding_cache(created.token_hash, created.bound_ip, created.status)
return BindingCheckResult(
allowed=True,
status_code=200,
detail="First-use bind created.",
token_hash=token_hash,
token_display=token_display,
bound_ip=created.bound_ip,
newly_bound=True,
await self.sync_binding_cache(
created.token_hash,
created.bound_ip,
created.binding_mode,
created.allowed_ips,
created.status,
)
return self.allowed_result(token_hash, token_display, created.bound_ip, "First-use bind created.", newly_bound=True)
async def sync_binding_cache(self, token_hash: str, bound_ip: str, status_code: int) -> None:
async def sync_binding_cache(
self,
token_hash: str,
bound_ip: str,
binding_mode: str,
allowed_ips: list[str],
status_code: int,
) -> None:
if self.redis is None:
return
payload = json.dumps({"bound_ip": bound_ip, "status": status_code})
payload = json.dumps(
{
"bound_ip": bound_ip,
"binding_mode": binding_mode,
"allowed_ips": allowed_ips,
"status": status_code,
}
)
try:
await self.redis.set(self.cache_key(token_hash), payload, ex=self.settings.redis_binding_ttl_seconds)
except Exception:
@@ -336,7 +366,7 @@ class BindingService:
)
return series
async def _load_binding_from_cache(self, token_hash: str) -> tuple[BindingRecord | None, bool]:
async def _load_binding_from_cache(self, token_hash: str, client_ip: str) -> tuple[BindingRecord | None, bool]:
if self.redis is None:
return None, False
try:
@@ -348,14 +378,18 @@ class BindingService:
return None, True
data = json.loads(raw)
allowed_ips = [str(item) for item in data.get("allowed_ips", [])]
binding_mode = str(data.get("binding_mode", BINDING_MODE_SINGLE))
return (
BindingRecord(
id=0,
token_hash=token_hash,
token_display="",
bound_ip=data["bound_ip"],
bound_ip=str(data.get("bound_ip", self.build_bound_ip_display(binding_mode, allowed_ips))),
binding_mode=binding_mode,
allowed_ips=allowed_ips,
status=int(data["status"]),
ip_matched=False,
ip_matched=self.is_client_allowed(client_ip, binding_mode, allowed_ips),
),
True,
)
@@ -369,69 +403,33 @@ class BindingService:
logger.warning("Failed to extend binding cache TTL.", extra={"token_hash": token_hash})
async def _load_binding_from_db(self, token_hash: str, client_ip: str) -> BindingRecord | None:
query = text(
"""
SELECT
id,
token_hash,
token_display,
bound_ip::text AS bound_ip,
status,
CAST(:client_ip AS inet) << bound_ip AS ip_matched
FROM token_bindings
WHERE token_hash = :token_hash
LIMIT 1
"""
)
async with self.session_factory() as session:
result = await session.execute(query, {"token_hash": token_hash, "client_ip": client_ip})
row = result.mappings().first()
if row is None:
binding = await session.scalar(select(TokenBinding).where(TokenBinding.token_hash == token_hash).limit(1))
if binding is None:
return None
return BindingRecord(
id=int(row["id"]),
token_hash=str(row["token_hash"]),
token_display=str(row["token_display"]),
bound_ip=str(row["bound_ip"]),
status=int(row["status"]),
ip_matched=bool(row["ip_matched"]),
)
return self.to_binding_record(binding, client_ip)
async def _create_binding(self, token_hash: str, token_display: str, client_ip: str) -> BindingRecord | None:
statement = text(
"""
INSERT INTO token_bindings (token_hash, token_display, bound_ip, status)
VALUES (:token_hash, :token_display, CAST(:bound_ip AS cidr), :status)
ON CONFLICT (token_hash) DO NOTHING
RETURNING id, token_hash, token_display, bound_ip::text AS bound_ip, status
"""
)
async with self.session_factory() as session:
try:
result = await session.execute(
statement,
{
"token_hash": token_hash,
"token_display": token_display,
"bound_ip": client_ip,
"status": STATUS_ACTIVE,
},
binding = TokenBinding(
token_hash=token_hash,
token_display=token_display,
bound_ip=client_ip,
binding_mode=BINDING_MODE_SINGLE,
allowed_ips=[client_ip],
status=STATUS_ACTIVE,
)
row = result.mappings().first()
session.add(binding)
await session.flush()
await session.commit()
except SQLAlchemyError:
await session.refresh(binding)
except SQLAlchemyError as exc:
await session.rollback()
if "duplicate key" in str(exc).lower() or "unique" in str(exc).lower():
return None
raise
if row is None:
return None
return BindingRecord(
id=int(row["id"]),
token_hash=str(row["token_hash"]),
token_display=str(row["token_display"]),
bound_ip=str(row["bound_ip"]),
status=int(row["status"]),
ip_matched=True,
)
return self.to_binding_record(binding, client_ip)
def _handle_backend_failure(self, token_hash: str, token_display: str) -> BindingCheckResult:
runtime_settings = self.runtime_settings_getter()