#!/usr/bin/env bash
# WireGuard ECMP latency-aware route manager — FAST FAILOVER
# Overlay (route destination): $WIREGUARD_SUBNET
# Client src (route source address): $WIREGUARD_PRIVATE_IP/32
# Both environment variables are required.

set -euo pipefail

LAYEROPS_HOME_DIR=/opt/layerops
LAYEROPS_BIN_DIR=${LAYEROPS_HOME_DIR}/bin
WIREGUARD_QUICK_RELOAD_PATH=${LAYEROPS_BIN_DIR}/wg-quick-up-reload
WIREGUARD_ROUTES_MTU=1200

OVERLAY="$WIREGUARD_SUBNET"
SRC_IP="$WIREGUARD_PRIVATE_IP"

# Interfaces
GW_INTERFACES=""
IF1=""
IF2=""

# Probing
PING_TIMEOUT_S=1      # per-probe timeout
SAMPLES=2             # pings per cycle (lower = faster)
SLEEP_SEC=1           # loop period (lower = faster)
PROBE1=""
PROBE2=""

# Latency weighting
DIFF_MS=10
RATIO_STRONG=1.5

# Hysteresis
FAILS_DOWN_FAST=1     # FAST FAIL: 1 failed probe => DOWN now
SUCCESSES_UP=3        # need 3 good cycles to restore link

# Handshake (advisory only)
HANDSHAKE_AGE_DOWN=120  # if very old AND probes fail -> keep DOWN

# Anti-flap: minimum time between *weight changes* when both links up
MIN_DWELL_SEC=15
last_change_ts=0

# Internal state
ok1=0; fail1=0; up1=true
ok2=0; fail2=0; up2=true

#######################################
# Print a timestamped, tagged log line to stdout.
# Arguments: message words (joined with spaces)
#######################################
log() {
  local stamp
  stamp=$(date +'%F %T')
  printf '[wg-ecmp] %s %s\n' "$stamp" "$*"
}

#######################################
# Seconds since the most recent WireGuard handshake on an interface.
# Considers every peer listed by `wg show <ifc> latest-handshakes` and uses
# the newest non-zero timestamp, so multi-peer interfaces are not judged by
# whichever peer happens to be listed first.
# Arguments: $1 - interface name (may be empty or non-existent)
# Outputs:   age in seconds, or "na" when no handshake is known (wg failed,
#            no peers, or all timestamps are 0)
# Returns:   0
#######################################
get_handshake_age() {
  local ifc="$1"
  local latest=0 _peer="" ts="" now
  # Each line is "<pubkey>\t<unix-ts>"; a missing final newline is tolerated.
  while read -r _peer ts _ || [[ -n "$_peer" ]]; do
    # 10# forces base 10 so a stray leading zero is not parsed as octal.
    if [[ "$ts" =~ ^[0-9]+$ ]] && (( 10#$ts > latest )); then
      latest=$(( 10#$ts ))
    fi
    _peer=""
  done < <(wg show "$ifc" latest-handshakes 2>/dev/null || true)

  if (( latest == 0 )); then
    echo "na"
    return 0
  fi
  now=$(date +%s)
  echo $(( now - latest ))
}

#######################################
# Average ICMP round-trip time to a peer, sent out of a given interface.
# Arguments: $1 - source interface (ping -I), $2 - destination address
# Globals:   SAMPLES, PING_TIMEOUT_S (read)
# Outputs:   mean RTT in ms, formatted "%.3f", over the pings that got a
#            reply; "down" if none of the SAMPLES pings did.
#######################################
probe_rtt() {
  local ifc="$1" dst="$2"
  local replies=0 sum=0 attempt reply ms
  for (( attempt = 1; attempt <= SAMPLES; attempt++ )); do
    # A lost ping is data, not an error: swallow ping's non-zero status.
    reply=$(ping -n -I "$ifc" -c1 -W "$PING_TIMEOUT_S" "$dst" 2>/dev/null || true)
    ms=$(sed -n 's/.*time=\([0-9.]\+\) ms.*/\1/p' <<<"$reply" | head -n1)
    if [[ -n "${ms:-}" ]]; then
      replies=$(( replies + 1 ))
      # bash has no floats: keep the running sum as an awk-formatted string.
      sum=$(awk -v a="$sum" -v b="$ms" 'BEGIN{printf "%.3f", a+b}')
    fi
    sleep 0.02   # small gap between consecutive pings
  done

  if (( replies == 0 )); then
    echo "down"
    return
  fi
  awk -v t="$sum" -v o="$replies" 'BEGIN{printf "%.3f", t/o}'
}

#######################################
# Install (replace) the overlay route with the given per-gateway weights.
# Arguments: $1 - weight for IF1, $2 - weight for IF2 (0 = do not use)
#   - w2 == 0 (this includes both weights being 0): single nexthop via IF1
#   - w1 == 0: single nexthop via IF2
#   - otherwise: ECMP over IF1 and IF2 with weights w1/w2
# Globals:   OVERLAY, SRC_IP, WIREGUARD_ROUTES_MTU, IF1, IF2 (read)
# Returns:   exit status of `ip route replace`
#######################################
apply_ecmp_weights() {
  local w1="$1" w2="$2"
  # Everything before the nexthop list is identical in all three cases.
  local -a route=(route replace "$OVERLAY" proto static scope global
                  src "$SRC_IP" mtu "$WIREGUARD_ROUTES_MTU")

  if (( w2 == 0 )); then
    ip "${route[@]}" nexthop dev "$IF1" weight 256
  elif (( w1 == 0 )); then
    ip "${route[@]}" nexthop dev "$IF2" weight 256
  else
    ip "${route[@]}" \
      nexthop dev "$IF1" weight "$w1" \
      nexthop dev "$IF2" weight "$w2"
  fi
}

#######################################
# List managed gateway interfaces: one per /etc/wireguard/wg.*.conf, named
# after the file without its ".conf" suffix, in glob (sorted) order.
# Outputs: interface names, one per line (nothing if no config matches)
#######################################
_discover_gw_interfaces() {
  local conf name
  for conf in /etc/wireguard/wg.*.conf; do
    [[ -e "$conf" ]] || continue   # an unmatched glob stays a literal pattern
    name="${conf##*/}"
    printf '%s\n' "${name%.conf}"
  done
}

#######################################
# Print the first /32 allowed-ip (prefix length stripped) of a WireGuard
# interface; it is used as that gateway's probe target. Only one address is
# returned so that a peer with several /32s still yields a pingable target.
# Arguments: $1 - interface name (may be empty)
# Outputs:   at most one address; nothing if none is found
#######################################
_wg_probe_target() {
  local ifc="$1" entry
  [[ -n "$ifc" ]] || return 0
  # Word-splitting is intended: lines are "<pubkey>\t<ip/len> <ip/len> ...".
  for entry in $(wg show "$ifc" allowed-ips 2>/dev/null); do
    if [[ "$entry" == */32 ]]; then
      printf '%s\n' "${entry%/32}"
      return 0
    fi
  done
  return 0
}

#######################################
# Choose ECMP weights for two healthy links from their average RTTs.
# Arguments: $1 - RTT of link 1 (ms), $2 - RTT of link 2 (ms)
# Globals:   DIFF_MS, RATIO_STRONG (read)
# Outputs:   "w1 w2": 256/64 in favour of the faster link when the absolute
#            gap exceeds DIFF_MS or the slow/fast ratio exceeds RATIO_STRONG,
#            otherwise "128 128".
#######################################
_latency_weights() {
  awk -v a="$1" -v b="$2" -v dthr="$DIFF_MS" -v rthr="$RATIO_STRONG" 'BEGIN {
    lo = (a < b) ? a : b
    hi = (a < b) ? b : a
    # A 0.000 ms RTT must not divide by zero (awk would abort and, under
    # set -e, take the whole daemon down); treat it as "no strong ratio".
    ratio = (lo > 0) ? hi / lo : 1
    if (hi - lo > dthr || ratio > rthr) {
      w = (a < b) ? "256 64" : "64 256"
    } else {
      w = "128 128"
    }
    print w
  }'
}

#######################################
# Control loop. Reloads every gateway interface once, then every SLEEP_SEC:
# re-discovers the first two gateways, probes them, updates per-link
# up/down hysteresis, picks ECMP weights and installs the overlay route
# when the computed state changes.
# Globals: reads the tunables defined at the top of the file; writes
#          GW_INTERFACES, IF1, IF2, PROBE1, PROBE2, ok*/fail*/up*,
#          last_change_ts.
# Returns: never (infinite loop).
#######################################
main() {
  local last_state="" last_w1=0 last_w2=0
  local ifc h1 h2 r1 r2 w1 w2 state now can_change

  GW_INTERFACES=$(_discover_gw_interfaces)
  for ifc in $GW_INTERFACES; do
    # One broken config must not take the whole failover daemon down.
    "$WIREGUARD_QUICK_RELOAD_PATH" "$ifc" || log "reload of $ifc failed; continuing"
  done

  while true; do
    # Re-scan each cycle; only the first two interfaces are managed.
    GW_INTERFACES=$(_discover_gw_interfaces)
    { read -r IF1 || true; read -r IF2 || true; } <<<"$GW_INTERFACES"

    # Recomputed (and cleared) every cycle so a vanished interface cannot
    # leave a stale probe target behind.
    PROBE1=$(_wg_probe_target "$IF1")
    PROBE2=$(_wg_probe_target "$IF2")

    # Admin-down shortcuts
    ip link show "$IF1" up >/dev/null 2>&1 || { up1=false; fail1=$FAILS_DOWN_FAST; }
    ip link show "$IF2" up >/dev/null 2>&1 || { up2=false; fail2=$FAILS_DOWN_FAST; }
    [[ -n "$PROBE1" ]] || { up1=false; fail1=$FAILS_DOWN_FAST; }
    [[ -n "$PROBE2" ]] || { up2=false; fail2=$FAILS_DOWN_FAST; }

    # Telemetry
    h1=$(get_handshake_age "$IF1") || h1="na"
    h2=$(get_handshake_age "$IF2") || h2="na"
    r1="down"; r2="down"
    if [[ -n "$PROBE1" ]]; then r1=$(probe_rtt "$IF1" "$PROBE1") || r1="down"; fi
    if [[ -n "$PROBE2" ]]; then r2=$(probe_rtt "$IF2" "$PROBE2") || r2="down"; fi

    # FAST FAIL logic: FAILS_DOWN_FAST failed cycles => DOWN immediately;
    # SUCCESSES_UP consecutive good cycles => UP again.
    if [[ "$r1" == "down" ]]; then
      fail1=$((fail1 + 1)); ok1=0
      if (( fail1 >= FAILS_DOWN_FAST )); then up1=false; fi
    else
      ok1=$((ok1 + 1)); fail1=0
      if (( ok1 >= SUCCESSES_UP )); then up1=true; fi
    fi
    # Advisory: a stale handshake plus failing probes keeps the link DOWN.
    if [[ "$r1" == "down" && "$h1" != "na" && "$h1" -gt $HANDSHAKE_AGE_DOWN ]]; then
      up1=false
    fi

    if [[ "$r2" == "down" ]]; then
      fail2=$((fail2 + 1)); ok2=0
      if (( fail2 >= FAILS_DOWN_FAST )); then up2=false; fi
    else
      ok2=$((ok2 + 1)); fail2=0
      if (( ok2 >= SUCCESSES_UP )); then up2=true; fi
    fi
    if [[ "$r2" == "down" && "$h2" != "na" && "$h2" -gt $HANDSHAKE_AGE_DOWN ]]; then
      up2=false
    fi

    w1=0; w2=0; state=""
    if ! $up1 && ! $up2; then
      state="both_down r1=$r1 r2=$r2 h1=$h1 h2=$h2"
    elif ! $up1; then
      state="s1_down r1=$r1 h1=$h1 | s2_up r2=$r2 h2=$h2"; w2=256
    elif ! $up2; then
      state="s2_down r2=$r2 h2=$h2 | s1_up r1=$r1 h1=$h1"; w1=256
    else
      # Both up: weight by latency (dwell below avoids flapping)
      read -r w1 w2 <<<"$(_latency_weights "$r1" "$r2")"
      state="up r1=${r1}ms r2=${r2}ms w1=$w1 w2=$w2 h1=$h1 h2=$h2"
    fi

    # Dwell only throttles re-weighting between two healthy links, i.e. when
    # both the installed route and the new one use both links. Failover and
    # recovery are applied at once; the hysteresis above debounces them.
    now=$(date +%s)
    can_change=true
    if (( w1 != 0 && w2 != 0 && last_w1 != 0 && last_w2 != 0 \
          && now - last_change_ts < MIN_DWELL_SEC )); then
      can_change=false
    fi

    if [[ "$state" != "$last_state" ]] && $can_change; then
      # Called as a condition so a rejected route (e.g. nexthop device down)
      # is logged and retried next cycle instead of killing the daemon.
      if apply_ecmp_weights "$w1" "$w2"; then
        last_change_ts=$now
        last_w1=$w1; last_w2=$w2
        log "$state (ok1=$ok1 fail1=$fail1 ok2=$ok2 fail2=$fail2)"
        last_state="$state"
      else
        log "failed to apply route for state: $state (will retry)"
      fi
    fi

    sleep "$SLEEP_SEC"
  done
}

# Entry point: run the control loop (arguments are forwarded; main ignores them).
main "$@"
