#!/bin/sh -eu

# Strict mode, repeated here because options on the "#!" line are dropped
# when the script is run as "sh forward ...".
set -eu

action=${1:-}
iface=${2:-}
options=${3:-}
# All three arguments are required; with fewer, only the usage is printed.
test $# -gt 2 || action=help

# Tag attached to every firewall rule this script adds, so clean can find them.
readonly comment="pro custodibus agent forward script"

# Prints usage information to stderr.
help() {
    cat << EOF >&2
Pro Custodibus Agent forward script.

Allows/disallows forwarding connections in/out of the specified WireGuard
interface. Run as root.

Usage:
  forward ACTION IFACE OPTIONS

Actions:
  up|post_up        Sets up forwarding as specified by OPTIONS
  down|pre_down     Cleans up forwarding
  pre_up|post_down  Does nothing

Options:
  all|true      Allows forwarding inbound and outbound connections
  inbound       Allows forwarding connections inbound from the WireGuard network
  outbound      Allows forwarding connections outbound to the WireGuard network
  internal      Allows forwarding connections within the WireGuard network
  clean|false   Cleans up forwarding

Examples:
  forward up wg0 outbound
EOF
}

# Prints $1 (default "ipv4") once if $iface has at least one IPv4 address;
# otherwise prints nothing. Meant for word lists such as $(has_ipv4 iptables).
# Only the third and later fields of the "ip -brief" line are inspected
# (the interface name and state come first), and each must look like a
# dotted-quad address to count.
has_ipv4() {
    ip -brief address show dev "$iface" |
        awk -v tag="${1:-ipv4}" '{
            for (i = 3; i <= NF; i++)
                if ($i ~ /^[0-9]+\.[0-9]+\.[0-9]+\.[0-9]+/) { print tag; exit }
        }'
}

# Prints $1 (default "ipv6") once if $iface has at least one IPv6 address;
# otherwise prints nothing. Meant for word lists such as $(has_ipv6 ip6tables).
# Only the third and later fields of the "ip -brief" line are inspected
# (the interface name and state come first); a colon marks an IPv6 address.
has_ipv6() {
    ip -brief address show dev "$iface" |
        awk -v tag="${1:-ipv6}" '{
            for (i = 3; i <= NF; i++)
                if ($i ~ /:/) { print tag; exit }
        }'
}

# Prints the 1-based position of the first catch-all "-j REJECT" or "-j DROP"
# rule in a chain, or nothing when the chain has none.
# Arguments follow firewall's rule layout: EXE -t TABLE -A CHAIN ...
# ($1 = command, $3 = table, $5 = chain; the rest is ignored).
rulenum_of_catchall() {
    local bin="$1" tbl="$3" chn="$5"
    # Policy (-P) lines are skipped so the count matches iptables rule numbers.
    "$bin" -t "$tbl" -S "$chn" |
        awk -v pat="^-A $chn -j (REJECT|DROP)" '
            /^-P/ { next }
            { idx++ }
            $0 ~ pat { print idx; exit }
        '
}

# Deletes every iptables and ip6tables rule, in any table, that mentions
# $iface via -i or -o and carries this script's $comment. Both families are
# always checked, whatever addresses the interface has right now. Each delete
# command is echoed with a "+ " prefix; all output goes to stderr.
clean() {
    local exe
    for exe in iptables ip6tables; do
        "$exe-save" | awk -v exe="$exe" -v iface="$iface" -v comment="$comment" '
        BEGIN { regex = " -[io] " iface " .*" comment }
        # "*filter", "*nat", ... open a table section; remember its name
        /^\*/ { table = substr($0, 2) }
        $0 ~ regex {
            sub(/^-A/, "-D")
            cmd = exe " -t " table " " $0
            print "+ " cmd
            # keep the trace line ahead of anything the command prints
            fflush()
            system(cmd)
        }
        ' >&2
    done
}

# Adds one firewall rule, given as iptables arguments of the form
# "-t TABLE -A CHAIN MATCH... -j TARGET", once per address family the
# interface uses: iptables if has_ipv4 reports IPv4, ip6tables if has_ipv6
# reports IPv6. Each rule is tagged with $comment so clean can find it later.
# If the chain has a catch-all REJECT/DROP rule, the new rule is inserted just
# before it instead of being appended. Commands are echoed to stderr.
firewall() {
    local exe rule rulenum
    for exe in $(has_ipv4 iptables) $(has_ipv6 ip6tables); do
        # Start from the caller's rule on every pass: each family's chain has
        # its own catch-all position (or none), so an -I rewrite made for
        # iptables must not carry over to ip6tables.
        rule="$*"
        rulenum=$(rulenum_of_catchall "$exe" "$@")
        if [ "$rulenum" ]; then
            rule=$(printf '%s\n' "$rule" | sed -E "s/-A ([^ ]+)/-I \1 $rulenum/")
        fi
        echo + "$exe" $rule -m comment --comment "\"$comment\"" >&2
        # $rule is deliberately unquoted so it splits back into arguments.
        $exe $rule -m comment --comment "$comment"
    done
}

# Sets sysctl net.<family>.$1 to $2 for each address family (ipv4/ipv6) that
# has_ipv4/has_ipv6 report for the interface. The write is skipped when the
# current value already equals $2. Each change is echoed to stderr, like the
# iptables commands in firewall and clean.
net_param() {
    local version
    for version in $(has_ipv4) $(has_ipv6); do
        if [ "$(sysctl -n "net.$version.$1")" != "$2" ]; then
            echo + sysctl -w "net.$version.$1=\"$2\"" >&2
            sysctl -w "net.$version.$1=$2"
        fi
    done
}

post_up_inbound() {
    # Accept forwarded traffic arriving from the WireGuard interface.
    firewall -t filter -A FORWARD \
        -i "$iface" \
        -j ACCEPT
}

post_up_outbound() {
    # Accept forwarded traffic leaving through the WireGuard interface.
    firewall -t filter -A FORWARD \
        -o "$iface" \
        -j ACCEPT
}

post_up_internal() {
    # Accept forwarded traffic that both enters and leaves through the
    # WireGuard interface, i.e. within the WireGuard network.
    firewall -t filter -A FORWARD \
        -i "$iface" -o "$iface" \
        -j ACCEPT
}

# Sets up forwarding for $iface as selected by $options:
#   1. Removes rules left over from an earlier run.
#   2. Turns on kernel forwarding for the address families the interface uses.
#   3. Adds the ACCEPT rules the option asks for.
# clean|false adds no rules. Any other unrecognized option also adds none,
# but is reported on stderr so a configuration typo does not go unnoticed.
post_up() {
    clean
    net_param conf.all.forwarding 1
    case "$options" in
        all|true) post_up_inbound; post_up_outbound ;;
        inbound) post_up_inbound ;;
        outbound) post_up_outbound ;;
        internal) post_up_internal ;;
        clean|false) ;;
        *) echo "forward: unknown option '$options'; no forwarding rules added" >&2 ;;
    esac
}

# Removes this script's firewall rules for $iface; sysctls are left untouched.
pre_down() { clean; }

# Dispatch on the wg-quick hook name; pre_up and post_down need no work.
case "$action" in
    up|post_up) post_up ;;
    down|pre_down) pre_down ;;
    pre_up|post_down) ;;
    *) help ;;
esac
