112 lines
4.2 KiB
Python
112 lines
4.2 KiB
Python
|
|
from __future__ import annotations
|
||
|
|
|
||
|
|
import logging
|
||
|
|
from urllib.parse import urlsplit
|
||
|
|
|
||
|
|
import httpx
|
||
|
|
from fastapi import APIRouter, Depends, Request
|
||
|
|
from fastapi.responses import JSONResponse, StreamingResponse
|
||
|
|
|
||
|
|
from app.config import Settings
|
||
|
|
from app.core.ip_utils import extract_client_ip
|
||
|
|
from app.core.security import extract_bearer_token
|
||
|
|
from app.dependencies import get_alert_service, get_binding_service, get_settings
|
||
|
|
from app.services.alert_service import AlertService
|
||
|
|
from app.services.binding_service import BindingService
|
||
|
|
|
||
|
|
logger = logging.getLogger(__name__)
|
||
|
|
|
||
|
|
router = APIRouter()
|
||
|
|
CONTENT_LENGTH_HEADER = "content-length"
|
||
|
|
|
||
|
|
|
||
|
|
# Headers that describe a single connection rather than the message itself.
# A proxy must consume these instead of forwarding them (RFC 9110, 7.6.1).
# "host" is listed too because it is always replaced with the downstream host.
_UNFORWARDED_REQUEST_HEADERS = frozenset(
    {
        "host",
        "connection",
        "keep-alive",
        "proxy-authenticate",
        "proxy-authorization",
        "proxy-connection",
        "te",
        "trailer",
        "transfer-encoding",
        "upgrade",
    }
)


def build_upstream_headers(request: Request, downstream_url: str) -> list[tuple[str, str]]:
    """Build the header list to send to the downstream service.

    Every header yielded by ``request.headers.items()`` is copied in order,
    except the client's ``Host`` header and the hop-by-hop headers, which only
    apply to the client-to-proxy connection. A single ``host`` header naming
    the downstream service is appended last.

    Args:
        request: Incoming request whose headers are forwarded.
        downstream_url: Base URL of the downstream service; only its network
            location (host and optional port) is used.

    Returns:
        ``(name, value)`` pairs suitable for an outgoing request.
    """
    # Drop any "user:password@" prefix so credentials embedded in the
    # configured URL never end up in the Host header.
    downstream_host = urlsplit(downstream_url).netloc.rpartition("@")[2]

    headers: list[tuple[str, str]] = [
        (header_name, header_value)
        for header_name, header_value in request.headers.items()
        if header_name.lower() not in _UNFORWARDED_REQUEST_HEADERS
    ]
    headers.append(("host", downstream_host))
    return headers
|
||
|
|
|
||
|
|
|
||
|
|
def build_upstream_url(settings: Settings, request: Request) -> str:
    """Return the downstream URL that corresponds to the incoming request path.

    The request path is appended to ``settings.downstream_url``. Trailing
    slashes on the configured base URL are removed first, so a base of
    ``http://host/`` and a path of ``/items`` give ``http://host/items``
    rather than ``http://host//items``.

    Args:
        settings: Application settings providing ``downstream_url``.
        request: Incoming request; only ``request.url.path`` is read.

    Returns:
        The absolute URL to call, without a query string.
    """
    base_url = str(settings.downstream_url).rstrip("/")
    return f"{base_url}{request.url.path}"
|
||
|
|
|
||
|
|
|
||
|
|
# Response headers that describe the proxy-to-downstream connection only.
# A proxy must not pass these on to its own client (RFC 9110, 7.6.1).
_HOP_BY_HOP_RESPONSE_HEADERS = frozenset(
    {
        "connection",
        "keep-alive",
        "proxy-authenticate",
        "proxy-connection",
        "te",
        "trailer",
        "transfer-encoding",
        "upgrade",
    }
)


def apply_downstream_headers(response: StreamingResponse, upstream_response: httpx.Response) -> None:
    """Copy the downstream response headers onto the response sent to the client.

    Headers are read with ``multi_items()`` and added with ``append`` so that
    repeated headers are kept as separate entries. The ``content-length``
    header and the hop-by-hop headers are skipped, leaving message framing and
    connection management to the server that delivers ``response``.

    Args:
        response: Outgoing response whose headers are extended in place.
        upstream_response: Response received from the downstream service.
    """
    for header_name, header_value in upstream_response.headers.multi_items():
        lowered_name = header_name.lower()
        if lowered_name == CONTENT_LENGTH_HEADER or lowered_name in _HOP_BY_HOP_RESPONSE_HEADERS:
            continue
        response.headers.append(header_name, header_value)
|
||
|
|
|
||
|
|
|
||
|
|
# HTTP methods accepted by both proxy routes.
_PROXIED_METHODS = ["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS", "HEAD"]


@router.api_route("/", methods=_PROXIED_METHODS, include_in_schema=False)
@router.api_route("/{path:path}", methods=_PROXIED_METHODS, include_in_schema=False)
async def reverse_proxy(
    request: Request,
    path: str = "",
    settings: Settings = Depends(get_settings),
    binding_service: BindingService = Depends(get_binding_service),
    alert_service: AlertService = Depends(get_alert_service),
):
    """Check the caller's bearer token binding, then relay the request downstream.

    Flow:
      1. Requests without a bearer token are counted as "allowed" and relayed.
      2. Requests with a token are evaluated by ``binding_service``. A rejected
         request is counted, optionally reported through ``alert_service``, and
         answered with the status code and detail from the binding result.
      3. Accepted requests are forwarded with the shared httpx client stored on
         ``request.app.state`` and the downstream body is streamed back.

    ``path`` only exists so the catch-all route can bind its path parameter;
    the forwarded path is taken from ``request.url`` instead.

    Returns:
        A ``JSONResponse`` when the request is rejected or the downstream call
        fails (502), otherwise a ``StreamingResponse`` relaying the downstream
        status, headers and body.
    """
    caller_ip = extract_client_ip(request, settings)
    bearer = extract_bearer_token(request.headers.get("authorization"))

    if not bearer:
        await binding_service.increment_request_metric("allowed")
    else:
        verdict = await binding_service.evaluate_token_binding(bearer, caller_ip)
        if not verdict.allowed:
            # NOTE(review): None is passed as the metric name when no alert is
            # due - confirm increment_request_metric treats that as a no-op.
            rejected_metric = "intercepted" if verdict.should_alert else None
            await binding_service.increment_request_metric(rejected_metric)
            # Alert only when every field the alert needs is present.
            if (
                verdict.should_alert
                and verdict.token_hash
                and verdict.token_display
                and verdict.bound_ip
            ):
                await alert_service.handle_intercept(
                    token_hash=verdict.token_hash,
                    token_display=verdict.token_display,
                    bound_ip=verdict.bound_ip,
                    attempt_ip=caller_ip,
                )
            return JSONResponse(
                status_code=verdict.status_code,
                content={"detail": verdict.detail},
            )
        await binding_service.increment_request_metric("allowed")

    client: httpx.AsyncClient = request.app.state.downstream_client
    target_url = build_upstream_url(settings, request)
    outgoing_headers = build_upstream_headers(request, settings.downstream_url)

    try:
        outgoing_request = client.build_request(
            request.method,
            target_url,
            params=request.query_params.multi_items(),
            headers=outgoing_headers,
            # Pass the body through as a stream instead of buffering it.
            content=request.stream(),
        )
        downstream_response = await client.send(outgoing_request, stream=True)
    except httpx.HTTPError as exc:
        logger.exception("Failed to reach downstream service.")
        return JSONResponse(status_code=502, content={"detail": f"Downstream request failed: {exc!s}"})

    async def relay_body():
        # Yield the undecoded downstream bytes and always release the
        # downstream response once iteration ends.
        try:
            async for chunk in downstream_response.aiter_raw():
                yield chunk
        finally:
            await downstream_response.aclose()

    proxied = StreamingResponse(
        relay_body(),
        status_code=downstream_response.status_code,
        media_type=None,
    )
    apply_downstream_headers(proxied, downstream_response)
    return proxied
|