Skip to content
Snippets Groups Projects
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
main.py 13.73 KiB
import asyncio
import logging.config
import math
from datetime import datetime
from typing import Any, Literal, Optional, Type, TypeVar
from uuid import UUID

import httpx
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, ConfigDict, Field

from src.common import resolve
from src.logging import LOGGING_CONFIG
from src.settings import PromConfig
from src.structures import TrafficStats

# Apply the project's dict-based logging configuration once, at import time.
logging.config.dictConfig(LOGGING_CONFIG)
logger = logging.getLogger(__name__)


# Type of the converted sample value returned by `_query_instant` (chosen via its `dtype` argument).
T = TypeVar("T")


async def _query_instant(
    client: httpx.AsyncClient, cfg: PromConfig, query: str, dtype: Type[T] = str
) -> tuple[Optional[dict], Optional[T]]:
    """Run a PromQL query and return the first sample of the result.

    Args:
        client: HTTP client used to issue the request.
        cfg: Prometheus settings; ``cfg.query_url`` is the endpoint that is called.
        query: PromQL expression, sent as the ``query`` request parameter.
        dtype: Callable applied to the raw sample value (second element of the
            first result's ``value`` pair) before it is returned.

    Returns:
        ``(metric, value)`` for the first result, where ``metric`` is its label
        dict and ``value`` the converted sample. ``(None, None)`` when the
        response status is not ``"success"``, the result list is empty, or the
        first result carries no usable ``value`` pair.

    Raises:
        HTTPException: 500 when the payload cannot be parsed or converted,
            503 when the request fails or Prometheus replies with an error status.
    """
    try:
        response = await client.get(cfg.query_url, params={"query": query}, timeout=10)
        response.raise_for_status()

        data = response.json()
        if data.get("status") != "success":
            return None, None

        results = data.get("data", {}).get("result", [])
        if not results:
            return None, None

        # Only the first series is considered; callers are expected to write
        # queries that yield a single series.
        first_result = results[0]
        metric = first_result.get("metric", {})
        value_list = first_result.get("value", [])
        if isinstance(value_list, list) and len(value_list) > 1:
            return metric, dtype(value_list[1])

        return None, None

    # AttributeError/TypeError cover payloads whose shape is not the expected
    # dict-of-dicts (e.g. ``"data": null``) and values ``dtype`` cannot convert,
    # matching the handling in ``_query_series``.
    except (AttributeError, KeyError, TypeError, ValueError) as parse_error:
        logger.exception("Error parsing response data:")
        raise HTTPException(status_code=500, detail=f"Error parsing Prometheus data: {parse_error}") from parse_error

    except httpx.RequestError as request_error:
        logger.exception("HTTP request error:")
        raise HTTPException(status_code=503, detail=f"HTTP request error: {request_error}") from request_error

    except httpx.HTTPStatusError as status_error:
        logger.exception(
            f"HTTP status error: {status_error.response.status_code} - {status_error.response.text}"
        )
        raise HTTPException(status_code=503, detail=f"HTTP status error: {status_error}") from status_error


async def _query_series(client: httpx.AsyncClient, cfg: PromConfig, query: str) -> list[dict[str, Any]]:
    """Run a PromQL query and return every series of the result.

    Args:
        client: HTTP client used to issue the request.
        cfg: Prometheus settings; ``cfg.query_url`` is the endpoint that is called.
        query: PromQL expression, sent as the ``query`` request parameter.

    Returns:
        The ``data.result`` list of a ``"success"`` response (empty when the
        payload has no ``data``/``result``), or ``[]`` for any other status.

    Raises:
        HTTPException: 500 when the payload cannot be parsed, 503 when the
            request fails or Prometheus replies with an error status.
    """
    try:
        response = await client.get(cfg.query_url, params={"query": query}, timeout=10)
        response.raise_for_status()

        data = response.json()
        if data.get("status") == "success":
            # The fallback must be a dict: a list has no ``.get`` and would turn
            # a merely absent "data" key into a parsing error.
            return data.get("data", {}).get("result", [])

        return []

    except (AttributeError, KeyError, ValueError) as parse_error:
        logger.exception("Error parsing series response data:")
        raise HTTPException(status_code=500, detail="Error parsing Prometheus series data.") from parse_error

    except httpx.RequestError as request_error:
        logger.exception(f"HTTP request error: {request_error}")
        raise HTTPException(status_code=503, detail=f"HTTP request error: {request_error}") from request_error

    except httpx.HTTPStatusError as status_error:
        logger.exception(
            f"HTTP status error: {status_error.response.status_code} - {status_error.response.text}"
        )
        raise HTTPException(status_code=503, detail=f"HTTP status error: {status_error}") from status_error


async def _get_pod_inbound_traffic_bps(
    client: httpx.AsyncClient, pod_name: str, timestamp: float, cfg: PromConfig
) -> Optional[float]:
    """Return the summed 15-minute rate of ``tcp_read_bytes_total`` for a pod.

    The rate is evaluated at ``timestamp`` (PromQL ``@`` modifier). Returns
    ``None`` when the query yields no sample.
    """
    promql = '(sum(rate(tcp_read_bytes_total{{pod="{pod}"}}[15m] @ {ts})))'.format(
        pod=pod_name, ts=timestamp
    )
    sample = await _query_instant(client, cfg, promql, float)
    # ``sample`` is the (metric, value) pair; only the value is of interest.
    return sample[1]


async def _get_pod_outbound_traffic_bps(
    client: httpx.AsyncClient, pod_name: str, timestamp: float, cfg: PromConfig
) -> Optional[float]:
    """Return the summed 15-minute rate of ``tcp_write_bytes_total`` for a pod.

    The rate is evaluated at ``timestamp`` (PromQL ``@`` modifier). Returns
    ``None`` when the query yields no sample.
    """
    promql = '(sum(rate(tcp_write_bytes_total{{pod="{pod}"}}[15m] @ {ts})))'.format(
        pod=pod_name, ts=timestamp
    )
    sample = await _query_instant(client, cfg, promql, float)
    # ``sample`` is the (metric, value) pair; only the value is of interest.
    return sample[1]


async def _get_pod_traffic_per_link_bytes_per_second(
    client: httpx.AsyncClient, pod_name: str, timestamp: float, cfg: PromConfig
) -> dict[str, dict[str, list[float]]]:
    """Return per-link byte rates for a pod, grouped by source and destination pod.

    Two queries (read and write byte counters, 15-minute rate evaluated at
    ``timestamp``) are issued concurrently and merged.

    Returns:
        ``{src_pod: {dst_pod: [rx_rate, tx_rate]}}``. A missing ``pod`` or
        ``dst_pod`` label is represented by an empty string key.

    Raises:
        HTTPException: whatever ``_query_series`` raised (status preserved), or
            500 when the returned series cannot be processed.
    """
    traffic_bytes: dict[str, dict[str, list[float]]] = {}

    # A counter of the total number of received bytes.
    rx_template = '(sum by (pod, dst_pod) (rate(tcp_read_bytes_total{{pod="{pod}"}}[15m] @ {ts})))'
    rx_query = rx_template.format(pod=pod_name, ts=timestamp)
    rx_task = _query_series(client, cfg, rx_query)

    # A counter of the total number of sent bytes.
    tx_template = '(sum by (pod, dst_pod) (rate(tcp_write_bytes_total{{pod="{pod}"}}[15m] @ {ts})))'
    tx_query = tx_template.format(pod=pod_name, ts=timestamp)
    tx_task = _query_series(client, cfg, tx_query)

    try:
        stats = await asyncio.gather(rx_task, tx_task)

        # ``idx`` follows the gather order: 0 -> rx rate, 1 -> tx rate.
        for idx, results in enumerate(stats):
            for result in results:
                metric = result.get("metric", {})
                value = result.get("value", [])

                # safely extract source and destination pods
                src = metric.get("pod", "")
                dst = metric.get("dst_pod", "")

                if len(value) > 1:
                    rate = float(value[1])

                    # initialize nested structures if not present
                    traffic_bytes.setdefault(src, {}).setdefault(dst, [0.0, 0.0])

                    # increment the appropriate rate (rx_rate or tx_rate)
                    traffic_bytes[src][dst][idx] += rate

        return traffic_bytes

    except HTTPException:
        # Already a well-formed API error (e.g. 503 from the query helper);
        # re-wrapping it as a generic 500 would hide the real cause.
        raise

    except Exception as default_error:
        logger.exception("Error parsing response data:")
        raise HTTPException(status_code=500, detail="Error parsing data.") from default_error


async def _get_pod_inbound_traffic_rate(
    client: httpx.AsyncClient, pod_name: str, timestamp: float, cfg: PromConfig
) -> Optional[float]:
    """Return the summed 15-minute rate of ``request_total`` for a pod.

    The rate is evaluated at ``timestamp`` (PromQL ``@`` modifier). Returns
    ``None`` when the query yields no sample.
    """
    promql = 'sum(rate(request_total{{pod="{pod}"}}[15m] @ {ts}))'.format(pod=pod_name, ts=timestamp)
    sample = await _query_instant(client, cfg, promql, float)
    # ``sample`` is the (metric, value) pair; only the value is of interest.
    return sample[1]


async def _get_pod_outbound_traffic_rate(
    client: httpx.AsyncClient, pod_name: str, timestamp: float, cfg: PromConfig
) -> Optional[float]:
    """Return the summed 15-minute rate of ``response_total`` for a pod.

    The rate is evaluated at ``timestamp`` (PromQL ``@`` modifier). Returns
    ``None`` when the query yields no sample.
    """
    promql = 'sum(rate(response_total{{pod="{pod}"}}[15m] @ {ts}))'.format(pod=pod_name, ts=timestamp)
    sample = await _query_instant(client, cfg, promql, float)
    # ``sample`` is the (metric, value) pair; only the value is of interest.
    return sample[1]


async def _get_outbound_traffic_rate_by_status_code(
    client: httpx.AsyncClient, pod_name: str, timestamp: float, cfg: PromConfig
) -> dict[int, float]:
    """Return the 15-minute ``response_total`` rate of a pod per status code.

    Returns:
        Mapping of integer status code to its rate. Series without a
        ``status_code`` label, with a non-numeric one, without a sample value,
        or with a rate that is not strictly positive are left out.

    Raises:
        HTTPException: errors raised by ``_query_series`` are propagated; 503
            when a returned series cannot be processed.
    """
    template = 'sum by (status_code) (rate(response_total{{pod="{pod}"}}[15m] @ {ts}))'
    query = template.format(pod=pod_name, ts=timestamp)

    try:
        traffic_res_rate_by_code: dict[int, float] = {}
        results = await _query_series(client, cfg, query)

        for result in results:
            metric = result.get("metric", {})
            value = result.get("value", [])

            status_code = metric.get("status_code")
            if status_code is None or len(value) < 2:
                continue

            # Label values arrive as strings; convert so the keys match the
            # declared ``dict[int, float]`` contract.
            try:
                code = int(status_code)
            except (TypeError, ValueError):
                logger.warning("Skipping series with non-numeric status_code label: %r", status_code)
                continue

            status_code_rate = float(value[1])
            if status_code_rate > 0:
                traffic_res_rate_by_code[code] = status_code_rate

        return traffic_res_rate_by_code

    except (AttributeError, KeyError, ValueError) as parse_error:
        logger.error(f"Error parsing series response data: {parse_error}")
        raise HTTPException(
            status_code=503, detail=f"Error parsing Prometheus series data: {parse_error}"
        ) from parse_error


async def _get_pod_outbound_traffic_latency(
    client: httpx.AsyncClient, pod_name: str, timestamp: float, cfg: PromConfig
) -> dict[str, float]:
    """Return response-latency quantiles (p99, p95, p75, p50) for a pod.

    Each quantile is computed with ``histogram_quantile`` over the 15-minute
    rate of ``response_latency_ms_bucket`` evaluated at ``timestamp``; the four
    queries run concurrently.

    Returns:
        Mapping of quantile label (``"p99"``, ...) to its value. Quantiles for
        which the query returned no sample, or a non-finite one (NaN/Inf), are
        left out.

    Raises:
        HTTPException: errors raised by ``_query_instant`` are propagated
            unchanged; any other failure is reported as 503.
    """
    template = 'histogram_quantile({phi}, sum(rate(response_latency_ms_bucket{{pod="{pod}"}}[15m] @ {ts})) by (le))'
    queries = {
        "p99": template.format(pod=pod_name, ts=timestamp, phi=0.99),
        "p95": template.format(pod=pod_name, ts=timestamp, phi=0.95),
        "p75": template.format(pod=pod_name, ts=timestamp, phi=0.75),
        "p50": template.format(pod=pod_name, ts=timestamp, phi=0.50),
    }

    async def aquery(query: str) -> Optional[float]:
        _, value = await _query_instant(client, cfg, query, float)
        return value

    try:
        values = await asyncio.gather(*(aquery(query) for query in queries.values()))

    except HTTPException:
        # Keep the status/detail chosen by the query helper instead of nesting
        # it inside a second, less specific 503.
        raise

    except Exception as ex:
        logger.error(f"Request error while fetching pod traffic response stats:\n{ex}")
        raise HTTPException(status_code=503, detail=f"Request error while fetching data: {ex}") from ex

    # A histogram without observations makes histogram_quantile yield NaN, which
    # is not a meaningful latency and cannot be encoded as strict JSON.
    return {
        label: value
        for label, value in zip(queries, values, strict=True)
        if value is not None and math.isfinite(value)
    }


# FastAPI application instance on which the route handlers of this module are registered.
app = FastAPI()


@app.get("/health", tags=["Health"])
@app.get("/healthz", tags=["Health"])
@app.get("/liveness", tags=["Health", "Liveness"])
@app.get("/readyness", tags=["Health", "Readyness"])
async def health_check():
    """Answer every health/probe route with a constant "healthy" payload."""
    payload = {"status": "healthy"}
    return payload


class PodInfo(BaseModel):
    """Pod identity fields validated from a metric's label dict.

    Keys that are not declared here are ignored. ``workload_uid`` is read from
    the ``label_nemo_eu_workload`` key and may also be supplied by field name.
    """

    model_config = ConfigDict(extra="ignore", populate_by_name=True)

    cluster: str
    region: str
    node: str
    tenant: str

    namespace: str
    # deployment: Optional[str]
    service: str
    job: str
    pod: str

    # name: str
    uid: str
    # Optional: absent when the label set has no ``label_nemo_eu_workload`` entry.
    workload_uid: str | None = Field(default=None, validation_alias="label_nemo_eu_workload")


# Values accepted for ``PodInfoExtended.current_state``.
_PodPhase = Literal["Pending", "Running", "Succeeded", "Failed", "Unknown"]


class PodInfoExtended(PodInfo):
    """``PodInfo`` extended with state, replica count and traffic statistics."""

    current_state: _PodPhase
    status_changes_1h: int
    traffic_stats: TrafficStats
    replicas: int


@app.get("/api/v1/workloads")
async def get_workloads_list():
    """List pods from ``kube_pod_labels``, grouped by their workload label.

    Returns:
        ``{workload_uid: {"pods": [PodInfo, ...]}}``. Pods whose label set has
        no ``label_nemo_eu_workload`` entry are grouped under ``"none"``.

    Raises:
        HTTPException: errors raised by ``_query_series`` are propagated.
    """
    prom_config = PromConfig()
    query = 'kube_pod_labels{label_nemo_eu_workload=~".*"}'

    async with httpx.AsyncClient() as client:
        results = await _query_series(client, prom_config, query)

    workloads: dict[str, dict[str, list[PodInfo]]] = {}
    for result in results:
        pod = result.get("metric", {})
        # TODO: We don't filter, until Victor test everything out.
        workload_uid = pod.get("label_nemo_eu_workload", "none")
        if workload_uid is not None:
            # Debug-level logging instead of printing every series to stdout.
            logger.debug("Workload pod series: %s", result)
            workloads.setdefault(workload_uid, {}).setdefault("pods", []).append(PodInfo.model_validate(pod))

    return workloads


async def get_pod_by_name_or_uid(client: httpx.AsyncClient, pod_name_or_uid: str, cfg: PromConfig) -> dict:
    """Look up the ``kube_pod_labels`` label set of a pod by UID or by name.

    The identifier is first tried as a UID when it parses as a UUID; if that
    finds nothing, or it is not a UUID, it is used as the pod name.

    Args:
        client: HTTP client used for the Prometheus queries.
        pod_name_or_uid: Pod UID or pod name, as received from the caller.
        cfg: Prometheus settings passed through to the query helper.

    Returns:
        The label dict of the first matching series.

    Raises:
        HTTPException: 404 when no series matches; errors raised by
            ``_query_instant`` are propagated.
    """
    try:
        uid: Optional[UUID] = UUID(pod_name_or_uid)
    except ValueError:
        uid = None

    if uid is not None:
        pod, _ = await _query_instant(client, cfg, f'kube_pod_labels{{uid="{uid}"}}')
        if pod is not None:
            return pod

    # The identifier is caller-supplied: escape the characters that would end
    # the quoted PromQL label value so it cannot alter the query.
    pod_name = pod_name_or_uid.replace("\\", "\\\\").replace('"', '\\"').replace("\n", "\\n")
    pod, _ = await _query_instant(client, cfg, f'kube_pod_labels{{pod="{pod_name}"}}')
    if pod is None:
        raise HTTPException(status_code=404, detail=f"Pod '{pod_name_or_uid}' not found.")

    return pod


@app.get("/api/v1/pods/{pod_name_or_uid}")
async def get_pod_details(pod_name_or_uid: str):
    """Return labels, phase, replica count and traffic statistics of one pod.

    Args:
        pod_name_or_uid: Pod UID or pod name taken from the request path.

    Returns:
        A ``PodInfoExtended`` built from the pod's label set plus the values
        collected from Prometheus.

    Raises:
        HTTPException: 404 when the pod, or the ReplicaSet/Deployment chain used
            to determine its replica count, cannot be found; errors raised by
            the query helpers are propagated.
    """
    cfg = PromConfig()

    # One evaluation timestamp shared by all ``@``-anchored queries below.
    timestamp = datetime.now().timestamp()

    async def get_current_state(client, uid) -> Literal["Pending", "Running", "Succeeded", "Failed", "Unknown"]:
        # ``> 0`` keeps only the phase series whose value is non-zero.
        query = f'kube_pod_status_phase{{uid="{uid}"}} > 0'
        metric, _ = await _query_instant(client, cfg, query)
        # Without such a series the phase cannot be determined.
        if not metric:
            return "Unknown"
        return metric.get("phase", "Unknown")

    async def get_status_changes(client, uid) -> int:
        query = f'changes(kube_pod_status_phase{{uid="{uid}"}}[1h] @ {timestamp})'
        _, value = await _query_instant(client, cfg, query)
        # No sample in the window means nothing changed; go through float() so
        # a value formatted like "2.0" is accepted as well.
        return int(float(value)) if value is not None else 0

    async def get_num_replicas(client, pod) -> int:
        namespace = pod["namespace"]
        name = pod["pod"]

        def missing(what: str) -> HTTPException:
            return HTTPException(status_code=404, detail=f"No {what} found for pod '{name}'.")

        # Step 1: Get the ReplicaSet owner of the pod
        query = f'kube_pod_owner{{namespace="{namespace}", pod="{name}"}}'
        metric, _ = await _query_instant(client, cfg, query)
        replicaset_name = (metric or {}).get("owner_name")
        if replicaset_name is None:
            raise missing("owning ReplicaSet")

        # Step 2: Get the Deployment owner of the ReplicaSet
        query = f'kube_replicaset_owner{{namespace="{namespace}", replicaset="{replicaset_name}"}}'
        metric, _ = await _query_instant(client, cfg, query)
        deployment_name = (metric or {}).get("owner_name")
        if deployment_name is None:
            raise missing("owning Deployment")

        # Step 3: Read the replica count declared by that Deployment
        query = f'kube_deployment_spec_replicas{{namespace="{namespace}", deployment="{deployment_name}"}}'
        _, value = await _query_instant(client, cfg, query)
        if value is None:
            raise missing("Deployment replica count")
        return int(float(value))

    async with httpx.AsyncClient() as client:
        pod = await get_pod_by_name_or_uid(client, pod_name_or_uid, cfg)
        if pod is None:
            raise HTTPException(status_code=404, detail=f"Pod '{pod_name_or_uid}' not found.")

        uid = pod.get("uid")
        pod_name = pod.get("pod")

        # The entries below are stored as un-awaited coroutines on purpose.
        # NOTE(review): `resolve` (src.common) is presumably what awaits the
        # awaitables nested in the dict -- confirm there.

        # obtain last valid state
        pod["current_state"] = get_current_state(client, uid)

        # obtain number of status changes in last hour
        pod["status_changes_1h"] = get_status_changes(client, uid)

        # get number of replicas
        pod["replicas"] = get_num_replicas(client, pod)

        # Obtain traffic data. Request figures use the inbound helpers and
        # response figures the outbound ones, for bytes as well as for rates.
        pod["traffic_stats"] = {
            "req_rate": _get_pod_inbound_traffic_rate(client, pod_name, timestamp, cfg),
            "req_bytes": _get_pod_inbound_traffic_bps(client, pod_name, timestamp, cfg),
            "res_rate": _get_pod_outbound_traffic_rate(client, pod_name, timestamp, cfg),
            "res_bytes": _get_pod_outbound_traffic_bps(client, pod_name, timestamp, cfg),
            "res_rate_by_code": _get_outbound_traffic_rate_by_status_code(client, pod_name, timestamp, cfg),
            "res_time_quantiles_ms": _get_pod_outbound_traffic_latency(client, pod_name, timestamp, cfg),
        }

        pod = await resolve(pod)

    stats = PodInfoExtended.model_validate(pod)

    return stats


@app.get("/api/v1/pods/{pod_name_or_uid}/traffic")
async def get_traffic_stats(pod_name_or_uid: str):
    """Return per-link byte rates of one pod.

    Args:
        pod_name_or_uid: Pod UID or pod name taken from the request path.

    Returns:
        ``{src_pod: {dst_pod: [rx_rate, tx_rate]}}`` as produced by
        ``_get_pod_traffic_per_link_bytes_per_second``.

    Raises:
        HTTPException: 404 when the pod cannot be found; errors raised by the
            query helpers are propagated.
    """
    cfg = PromConfig()
    timestamp = datetime.now().timestamp()

    async with httpx.AsyncClient() as client:
        pod = await get_pod_by_name_or_uid(client, pod_name_or_uid, cfg)
        # An unknown pod must be a 404, not a TypeError on ``None["pod"]``.
        if pod is None:
            raise HTTPException(status_code=404, detail=f"Pod '{pod_name_or_uid}' not found.")

        results = await _get_pod_traffic_per_link_bytes_per_second(client, pod["pod"], timestamp, cfg)

    return results