//! This package provides instrumentation for creating Berkeley Packet Filter[1]
//! (BPF) programs, along with a simulator for running them.
//!
//! BPF is a mechanism for cheap, in-kernel packet filtering. Programs are
//! attached to a network device and executed for every packet that flows
//! through it. The program must then return a verdict: the amount of packet
//! bytes that the kernel should copy into userspace. Execution speed is
//! achieved by having programs run in a limited virtual machine, which has the
//! added benefit of graceful failure in the face of buggy programs.
//!
//! The BPF virtual machine has a 32-bit word length and a small number of
//! word-sized registers:
//!
//! - The accumulator, `a`: The source/destination of arithmetic and logic
//!   operations.
//! - The index register, `x`: Used as an offset for indirect memory access and
//!   as a comparison value for conditional jumps.
//! - The scratch memory store, `M[0]..M[15]`: Used for saving the value of a/x
//!   for later use.
//!
//! The packet being examined is an array of bytes, and is addressed using plain
//! array subscript notation, e.g. [10] for the byte at offset 10. An implicit
//! program counter, `pc`, is initialized to zero and incremented for each instruction.
//!
//! The machine has a fixed instruction set with the following form, where the
//! numbers represent bit length:
//!
//! ```
//! ┌───────────┬──────┬──────┐
//! │ opcode:16 │ jt:8 │ jf:8 │
//! ├───────────┴──────┴──────┤
//! │           k:32          │
//! └─────────────────────────┘
//! ```
//!
//! The `opcode` indicates the instruction class and its addressing mode.
//! Opcodes are generated by performing binary addition on the 8-bit class and
//! mode constants. For example, the opcode for loading a byte from the packet
//! at X + 2, (`ldb [x + 2]`), is:
//!
//! ```
//! LD | IND | B = 0x00 | 0x40 | 0x10
//!              = 0x50
//! ```
//!
//! `jt` is an offset used for conditional jumps, and increments the program
//! counter by its amount if the comparison was true. Conversely, `jf`
//! increments the counter if it was false. These fields are ignored in all
//! other cases. `k` is a generic variable used for various purposes, most
//! commonly as some sort of constant.
//!
//! This package contains opcode extensions used by different implementations,
//! where "extension" is anything outside of the original that was imported into
//! 4.4BSD[2]. These are marked with "EXTENSION", along with a list of
//! implementations that use them.
//!
//! Most of the doc-comments use the BPF assembly syntax as described in the
//! original paper[1]. For the sake of completeness, here is the complete
//! instruction set, along with the extensions:
//!
//! ```
//! opcode  addressing modes
//! ld      #k  #len    M[k]    [k]     [x + k]
//! ldh     [k] [x + k]
//! ldb     [k] [x + k]
//! ldx     #k  #len    M[k]    4 * ([k] & 0xf) arc4random()
//! st      M[k]
//! stx     M[k]
//! jmp     L
//! jeq     #k, Lt, Lf
//! jgt     #k, Lt, Lf
//! jge     #k, Lt, Lf
//! jset    #k, Lt, Lf
//! add     #k  x
//! sub     #k  x
//! mul     #k  x
//! div     #k  x
//! or      #k  x
//! and     #k  x
//! lsh     #k  x
//! rsh     #k  x
//! neg
//! mod     #k  x
//! xor     #k  x
//! ret     #k  a
//! tax
//! txa
//! ```
//!
//! Finally, a note on program design. The lack of backwards jumps leads to a
//! "return early, return often" control flow. Take for example the program
//! generated from the tcpdump filter `tcp`:
//!
//! ```
//! (000) ldh   [12]            ; Ethernet Packet Type
//! (001) jeq   #0x86dd, 2, 7   ; ETHERTYPE_IPV6
//! (002) ldb   [20]            ; IPv6 Next Header
//! (003) jeq   #0x6, 10, 4     ; TCP
//! (004) jeq   #0x2c, 5, 11    ; IPv6 Fragment Header
//! (005) ldb   [54]            ; IPv6 Fragment Next Header
//! (006) jeq   #0x6, 10, 11    ; IPPROTO_TCP
//! (007) jeq   #0x800, 8, 11   ; ETHERTYPE_IP
//! (008) ldb   [23]            ; IPv4 Protocol
//! (009) jeq   #0x6, 10, 11    ; IPPROTO_TCP
//! (010) ret   #262144         ; copy 0x40000
//! (011) ret   #0              ; skip packet
//! ```
//!
//! Here we can make a few observations:
//!
//! - The problem "filter only tcp packets" has essentially been transformed
//!   into a series of layer checks.
//! - There are two distinct branches in the code, one for validating IPv4
//!   headers and one for IPv6 headers.
//! - Most conditional jumps in these branches lead directly to the last two
//!   instructions, a pass or fail. Thus the goal of a program is to find the
//!   fastest route to a pass/fail comparison.
//!
//! [1]: S. McCanne and V. Jacobson, "The BSD Packet Filter: A New Architecture
//!      for User-level Packet Capture", Proceedings of the 1993 Winter USENIX.
//! [2]: https://minnie.tuhs.org/cgi-bin/utree.pl?file=4.4BSD/usr/src/sys/net/bpf.h
const std = @import("std");
const builtin = @import("builtin");
const native_endian = builtin.target.cpu.arch.endian();
const mem = std.mem;
const math = std.math;
const random = std.crypto.random;
const assert = std.debug.assert;
const expectEqual = std.testing.expectEqual;
const expectError = std.testing.expectError;
const expect = std.testing.expect;

// instruction classes

/// ld, ldh, ldb: Load data into a.
pub const LD = 0x00;
/// ldx: Load data into x.
pub const LDX = 0x01;
/// st:  Store into scratch memory the value of a.
pub const ST = 0x02;
/// stx: Store into scratch memory the value of x.
pub const STX = 0x03;
/// alu: Wrapping arithmetic/bitwise operations on a using the value of k/x.
pub const ALU = 0x04;
/// jmp, jeq, jgt, jge, jset: Increment the program counter based on a comparison
/// between k/x and the accumulator.
pub const JMP = 0x05;
/// ret: Return a verdict using the value of k/the accumulator.
pub const RET = 0x06;
/// tax, txa: Register value copying between x and a.
pub const MISC = 0x07;

// Size of data to be loaded from the packet.

/// ld: 32-bit full word.
pub const W = 0x00;
/// ldh: 16-bit half word.
pub const H = 0x08;
/// ldb: Single byte.
pub const B = 0x10;

// Addressing modes used for loads to a/x.

/// #k: The immediate value stored in k.
pub const IMM = 0x00;
/// [k]: The value at offset k in the packet.
pub const ABS = 0x20;
/// [x + k]: The value at offset x + k in the packet.
pub const IND = 0x40;
/// M[k]: The value of the k'th scratch memory register.
pub const MEM = 0x60;
/// #len: The size of the packet.
pub const LEN = 0x80;
/// 4 * ([k] & 0xf): Four times the low four bits of the byte at offset k in the
/// packet. This is used for efficiently loading the header length of an IP
/// packet.
pub const MSH = 0xa0;
/// arc4random: 32-bit integer generated from a CSPRNG (see arc4random(3)) loaded into a.
/// EXTENSION. Defined for:
/// - OpenBSD.
pub const RND = 0xc0;

// Modifiers for different instruction classes.

/// Use the value of k for alu operations (add #k).
/// Compare against the value of k for jumps (jeq #k, Lt, Lf).
/// Return the value of k for returns (ret #k).
pub const K = 0x00;
/// Use the value of x for alu operations (add x).
/// Compare against the value of x for jumps (jeq x, Lt, Lf).
pub const X = 0x08;
/// Return the value of a for returns (ret a).
pub const A = 0x10;

// ALU Operations on a using the value of k/x.

// All arithmetic operations are defined to wrap on overflow of a.

/// add: a = a + k
///      a = a + x.
pub const ADD = 0x00;
/// sub: a = a - k
///      a = a - x.
pub const SUB = 0x10;
/// mul: a = a * k
///      a = a * x.
pub const MUL = 0x20;
/// div: a = a / k
///      a = a / x.
/// Truncated division.
pub const DIV = 0x30;
/// or:  a = a | k
///      a = a | x.
pub const OR = 0x40;
/// and: a = a & k
///      a = a & x.
pub const AND = 0x50;
/// lsh: a = a << k
///      a = a << x.
pub const LSH = 0x60;
/// rsh: a = a >> k
///      a = a >> x.
pub const RSH = 0x70;
/// neg: a = -a.
/// Note that this isn't a binary negation, rather the value of `~a + 1`.
pub const NEG = 0x80;
/// mod: a = a % k
///      a = a % x.
/// EXTENSION. Defined for:
///  - Linux.
///  - NetBSD + Minix 3.
///  - FreeBSD and derivatives.
pub const MOD = 0x90;
/// xor: a = a ^ k
///      a = a ^ x.
/// EXTENSION. Defined for:
///  - Linux.
///  - NetBSD + Minix 3.
///  - FreeBSD and derivatives.
pub const XOR = 0xa0;

// Jump operations using a comparison between a and x/k.
// Every offset below is relative to the instruction that follows the jump:
// the program counter still advances by one after the offset is added.

/// jmp L: pc += k.
/// No comparison done here.
pub const JA = 0x00;
/// jeq    #k, Lt, Lf: pc += (a == k)      ? jt : jf.
/// jeq     x, Lt, Lf: pc += (a == x)      ? jt : jf.
pub const JEQ = 0x10;
/// jgt    #k, Lt, Lf: pc += (a >  k)      ? jt : jf.
/// jgt     x, Lt, Lf: pc += (a >  x)      ? jt : jf.
pub const JGT = 0x20;
/// jge    #k, Lt, Lf: pc += (a >= k)      ? jt : jf.
/// jge     x, Lt, Lf: pc += (a >= x)      ? jt : jf.
pub const JGE = 0x30;
/// jset   #k, Lt, Lf: pc += ((a & k) > 0) ? jt : jf.
/// jset    x, Lt, Lf: pc += ((a & x) > 0) ? jt : jf.
pub const JSET = 0x40;

// Miscellaneous operations/register copy.

/// tax: x = a.
pub const TAX = 0x00;
/// txa: a = x.
pub const TXA = 0x80;

/// The 16 registers in the scratch memory store as named enums.
/// The integer value of each tag is the register's index into the store.
pub const Scratch = enum(u4) { m0, m1, m2, m3, m4, m5, m6, m7, m8, m9, m10, m11, m12, m13, m14, m15 };
/// Number of 32-bit words in the scratch memory store.
pub const MEMWORDS = 16;
/// Limit on the number of instructions in a program: 4096 when targeting
/// Linux, 512 for every other OS. `simulate` asserts the program length
/// against this value.
pub const MAXINSNS = switch (builtin.os.tag) {
    .linux => 4096,
    else => 512,
};
/// NOTE(review): presumably the smallest permitted capture buffer size in
/// bytes; nothing in the code visible here uses it — confirm.
pub const MINBUFSIZE = 32;
/// Upper bound (2 MiB) that `simulate` asserts the packet length against.
pub const MAXBUFSIZE = 1 << 21;

/// A single BPF instruction: a 16-bit opcode, two 8-bit conditional jump
/// offsets and a 32-bit general purpose operand.
///
/// The builder functions assemble the opcode by OR-ing together the class,
/// size, addressing-mode, operation and source constants declared at file
/// scope, so callers never have to combine those bits by hand.
pub const Insn = extern struct {
    /// Instruction class OR-ed with its size/mode/operation/source bits.
    opcode: u16,
    /// Conditional jumps only: offset applied when the comparison is true.
    jt: u8,
    /// Conditional jumps only: offset applied when the comparison is false.
    jf: u8,
    /// Generic operand: immediate value, packet offset, scratch register
    /// index or jump distance depending on the opcode.
    k: u32,

    /// Implements the `std.fmt.format` API.
    /// The formatting is similar to the output of tcpdump -dd: the opcode is
    /// written as at least two hex digits, `jt` and `jf` in decimal and `k` as
    /// eight hex digits, with zeroes padded on the left.
    /// Only the empty and `s` format specifiers are accepted; anything else is
    /// a compile error.
    pub fn format(
        self: Insn,
        comptime layout: []const u8,
        opts: std.fmt.FormatOptions,
        writer: anytype,
    ) !void {
        _ = opts;
        if (comptime layout.len != 0 and layout[0] != 's')
            @compileError("Unsupported format specifier for BPF Insn type '" ++ layout ++ "'.");

        // `>` right-aligns the digits so that the `0` fill is placed in front
        // of them. With `<` the fill trails the digits and a value such as 0x6
        // would be rendered as "60".
        try std.fmt.format(
            writer,
            "Insn{{ 0x{X:0>2}, {d}, {d}, 0x{X:0>8} }}",
            .{ self.opcode, self.jt, self.jf, self.k },
        );
    }

    /// Width of a packet load.
    const Size = enum(u8) {
        word = W,
        half_word = H,
        byte = B,
    };

    /// Builds an instruction that does not jump conditionally: `jt` and `jf`
    /// are both zero.
    fn stmt(opcode: u16, k: u32) Insn {
        return .{
            .opcode = opcode,
            .jt = 0,
            .jf = 0,
            .k = k,
        };
    }

    /// ld #k: load the immediate `value` into a.
    pub fn ld_imm(value: u32) Insn {
        return stmt(LD | IMM, value);
    }

    /// ld/ldh/ldb [k]: load `size` bytes at packet offset `offset` into a.
    pub fn ld_abs(size: Size, offset: u32) Insn {
        return stmt(LD | ABS | @enumToInt(size), offset);
    }

    /// ld/ldh/ldb [x + k]: load `size` bytes at packet offset x + `offset` into a.
    pub fn ld_ind(size: Size, offset: u32) Insn {
        return stmt(LD | IND | @enumToInt(size), offset);
    }

    /// ld M[k]: load scratch register `reg` into a.
    pub fn ld_mem(reg: Scratch) Insn {
        return stmt(LD | MEM, @enumToInt(reg));
    }

    /// ld #len: load the packet length into a.
    pub fn ld_len() Insn {
        return stmt(LD | LEN | W, 0);
    }

    /// ld arc4random(): load a random 32-bit integer into a (EXTENSION).
    pub fn ld_rnd() Insn {
        return stmt(LD | RND | W, 0);
    }

    /// ldx #k: load the immediate `value` into x.
    pub fn ldx_imm(value: u32) Insn {
        return stmt(LDX | IMM, value);
    }

    /// ldx M[k]: load scratch register `reg` into x.
    pub fn ldx_mem(reg: Scratch) Insn {
        return stmt(LDX | MEM, @enumToInt(reg));
    }

    /// ldx #len: load the packet length into x.
    pub fn ldx_len() Insn {
        return stmt(LDX | LEN | W, 0);
    }

    /// ldx 4 * ([k] & 0xf): load four times the low nibble of the packet byte
    /// at `offset` into x.
    pub fn ldx_msh(offset: u32) Insn {
        return stmt(LDX | MSH | B, offset);
    }

    /// st M[k]: store a into scratch register `reg`.
    pub fn st(reg: Scratch) Insn {
        return stmt(ST, @enumToInt(reg));
    }

    /// stx M[k]: store x into scratch register `reg`.
    pub fn stx(reg: Scratch) Insn {
        return stmt(STX, @enumToInt(reg));
    }

    /// Binary ALU operations. `neg` is unary and has its own builder, `alu_neg`.
    const AluOp = enum(u16) {
        add = ADD,
        sub = SUB,
        mul = MUL,
        div = DIV,
        @"or" = OR,
        @"and" = AND,
        lsh = LSH,
        rsh = RSH,
        mod = MOD,
        xor = XOR,
    };

    /// Where the second operand of an ALU operation or comparison comes from.
    const Source = enum(u16) {
        k = K,
        x = X,
    };
    /// Either an immediate operand (`.k`) or the x register (`.x`).
    const KOrX = union(Source) {
        k: u32,
        x: void,
    };

    /// neg: a = -a.
    pub fn alu_neg() Insn {
        return stmt(ALU | NEG, 0);
    }

    /// Applies `op` to a, using either an immediate (`.{ .k = n }`) or x (`.x`).
    /// `k` is left as zero when the source is x.
    pub fn alu(op: AluOp, source: KOrX) Insn {
        return stmt(
            ALU | @enumToInt(op) | @enumToInt(source),
            switch (source) {
                .k => |value| value,
                .x => 0,
            },
        );
    }

    /// Comparisons available to conditional jumps.
    const JmpOp = enum(u16) {
        jeq = JEQ,
        jgt = JGT,
        jge = JGE,
        jset = JSET,
    };

    /// jmp L: unconditionally skip the next `location` instructions.
    pub fn jmp_ja(location: u32) Insn {
        return stmt(JMP | JA, location);
    }

    /// Compares a against an immediate or x using `op`, then skips `jt`
    /// instructions when the comparison holds and `jf` instructions otherwise.
    /// `k` is left as zero when the source is x.
    pub fn jmp(op: JmpOp, source: KOrX, jt: u8, jf: u8) Insn {
        return Insn{
            .opcode = JMP | @enumToInt(op) | @enumToInt(source),
            .jt = jt,
            .jf = jf,
            .k = switch (source) {
                .k => |value| value,
                .x => 0,
            },
        };
    }

    /// Where the verdict of a `ret` comes from.
    const Verdict = enum(u16) {
        k = K,
        a = A,
    };
    /// Either an immediate verdict (`.k`) or the accumulator (`.a`).
    const KOrA = union(Verdict) {
        k: u32,
        a: void,
    };

    /// ret #k / ret a: finish the program with the given verdict.
    /// `k` is left as zero when the verdict is the accumulator.
    pub fn ret(verdict: KOrA) Insn {
        return stmt(
            RET | @enumToInt(verdict),
            switch (verdict) {
                .k => |value| value,
                .a => 0,
            },
        );
    }

    /// tax: x = a.
    pub fn tax() Insn {
        return stmt(MISC | TAX, 0);
    }

    /// txa: a = x.
    pub fn txa() Insn {
        return stmt(MISC | TXA, 0);
    }
};

/// Test helper: fails unless the opcode stored in `insn` equals `opcode`.
fn opcodeEqual(opcode: u16, insn: Insn) !void {
    const actual = insn.opcode;
    return expectEqual(opcode, actual);
}

test "opcodes" {
    // Each row pairs the expected raw opcode with the builder call that should
    // produce it. Rows are checked in order.
    const Case = struct { opcode: u16, insn: Insn };
    const cases = [_]Case{
        // Loads into a.
        .{ .opcode = 0x00, .insn = Insn.ld_imm(0) },
        .{ .opcode = 0x20, .insn = Insn.ld_abs(.word, 0) },
        .{ .opcode = 0x28, .insn = Insn.ld_abs(.half_word, 0) },
        .{ .opcode = 0x30, .insn = Insn.ld_abs(.byte, 0) },
        .{ .opcode = 0x40, .insn = Insn.ld_ind(.word, 0) },
        .{ .opcode = 0x48, .insn = Insn.ld_ind(.half_word, 0) },
        .{ .opcode = 0x50, .insn = Insn.ld_ind(.byte, 0) },
        .{ .opcode = 0x60, .insn = Insn.ld_mem(.m0) },
        .{ .opcode = 0x80, .insn = Insn.ld_len() },
        .{ .opcode = 0xc0, .insn = Insn.ld_rnd() },

        // Loads into x.
        .{ .opcode = 0x01, .insn = Insn.ldx_imm(0) },
        .{ .opcode = 0x61, .insn = Insn.ldx_mem(.m0) },
        .{ .opcode = 0x81, .insn = Insn.ldx_len() },
        .{ .opcode = 0xb1, .insn = Insn.ldx_msh(0) },

        // Stores.
        .{ .opcode = 0x02, .insn = Insn.st(.m0) },
        .{ .opcode = 0x03, .insn = Insn.stx(.m0) },

        // ALU with an immediate operand.
        .{ .opcode = 0x04, .insn = Insn.alu(.add, .{ .k = 0 }) },
        .{ .opcode = 0x14, .insn = Insn.alu(.sub, .{ .k = 0 }) },
        .{ .opcode = 0x24, .insn = Insn.alu(.mul, .{ .k = 0 }) },
        .{ .opcode = 0x34, .insn = Insn.alu(.div, .{ .k = 0 }) },
        .{ .opcode = 0x44, .insn = Insn.alu(.@"or", .{ .k = 0 }) },
        .{ .opcode = 0x54, .insn = Insn.alu(.@"and", .{ .k = 0 }) },
        .{ .opcode = 0x64, .insn = Insn.alu(.lsh, .{ .k = 0 }) },
        .{ .opcode = 0x74, .insn = Insn.alu(.rsh, .{ .k = 0 }) },
        .{ .opcode = 0x94, .insn = Insn.alu(.mod, .{ .k = 0 }) },
        .{ .opcode = 0xa4, .insn = Insn.alu(.xor, .{ .k = 0 }) },
        .{ .opcode = 0x84, .insn = Insn.alu_neg() },

        // ALU with x as the operand.
        .{ .opcode = 0x0c, .insn = Insn.alu(.add, .x) },
        .{ .opcode = 0x1c, .insn = Insn.alu(.sub, .x) },
        .{ .opcode = 0x2c, .insn = Insn.alu(.mul, .x) },
        .{ .opcode = 0x3c, .insn = Insn.alu(.div, .x) },
        .{ .opcode = 0x4c, .insn = Insn.alu(.@"or", .x) },
        .{ .opcode = 0x5c, .insn = Insn.alu(.@"and", .x) },
        .{ .opcode = 0x6c, .insn = Insn.alu(.lsh, .x) },
        .{ .opcode = 0x7c, .insn = Insn.alu(.rsh, .x) },
        .{ .opcode = 0x9c, .insn = Insn.alu(.mod, .x) },
        .{ .opcode = 0xac, .insn = Insn.alu(.xor, .x) },

        // Jumps.
        .{ .opcode = 0x05, .insn = Insn.jmp_ja(0) },
        .{ .opcode = 0x15, .insn = Insn.jmp(.jeq, .{ .k = 0 }, 0, 0) },
        .{ .opcode = 0x25, .insn = Insn.jmp(.jgt, .{ .k = 0 }, 0, 0) },
        .{ .opcode = 0x35, .insn = Insn.jmp(.jge, .{ .k = 0 }, 0, 0) },
        .{ .opcode = 0x45, .insn = Insn.jmp(.jset, .{ .k = 0 }, 0, 0) },
        .{ .opcode = 0x1d, .insn = Insn.jmp(.jeq, .x, 0, 0) },
        .{ .opcode = 0x2d, .insn = Insn.jmp(.jgt, .x, 0, 0) },
        .{ .opcode = 0x3d, .insn = Insn.jmp(.jge, .x, 0, 0) },
        .{ .opcode = 0x4d, .insn = Insn.jmp(.jset, .x, 0, 0) },

        // Returns.
        .{ .opcode = 0x06, .insn = Insn.ret(.{ .k = 0 }) },
        .{ .opcode = 0x16, .insn = Insn.ret(.a) },

        // Register copies.
        .{ .opcode = 0x07, .insn = Insn.tax() },
        .{ .opcode = 0x87, .insn = Insn.txa() },
    };

    for (cases) |case| try opcodeEqual(case.opcode, case.insn);
}

/// Errors that `simulate` can return.
pub const Error = error{
    /// The opcode is not one the simulator implements.
    InvalidOpcode,
    /// A packet load or scratch memory access is out of bounds, or the target
    /// of an unconditional `jmp` lies beyond the end of the program.
    InvalidOffset,
    /// The `jt` or `jf` target of a conditional jump lies beyond the end of
    /// the program.
    InvalidLocation,
    /// Propagated from the `div`/`mod` instructions when the divisor is zero.
    DivisionByZero,
    /// Execution ran off the end of the program without reaching a `ret`.
    NoReturn,
};

/// A simple implementation of the BPF virtual-machine.
/// Use this to run/debug programs.
///
/// Executes `filter` against `packet` with a, x and every scratch memory word
/// starting at zero, and returns the verdict of the first `ret` that is
/// reached. Multi-byte packet loads are decoded using `byte_order`.
///
/// Asserts that `filter` holds between 1 and `MAXINSNS` instructions
/// (inclusive) and that `packet` is shorter than `MAXBUFSIZE` bytes.
///
/// Errors:
/// - `InvalidOpcode`: an instruction's opcode is not implemented.
/// - `InvalidOffset`: a packet load or scratch access is out of bounds, or an
///   unconditional `jmp` targets a location past the end of the program.
/// - `InvalidLocation`: a conditional jump's `jt` or `jf` targets a location
///   past the end of the program.
/// - `DivisionByZero`: propagated from `div`/`mod` with a zero divisor.
/// - `NoReturn`: the program ended without executing a `ret`.
pub fn simulate(
    packet: []const u8,
    filter: []const Insn,
    byte_order: std.builtin.Endian,
) Error!u32 {
    // MAXINSNS is the largest permitted program length, so a program of
    // exactly that many instructions must be accepted.
    assert(filter.len > 0 and filter.len <= MAXINSNS);
    assert(packet.len < MAXBUFSIZE);
    const len = @intCast(u32, packet.len);

    var a: u32 = 0;
    var x: u32 = 0;
    var m = mem.zeroes([MEMWORDS]u32);
    var pc: usize = 0;

    // `pc` is advanced by one after every instruction. Jumps add their offset
    // on top of that, which makes offsets relative to the next instruction.
    while (pc < filter.len) : (pc += 1) {
        const i = filter[pc];
        // Cast to a wider type to protect against overflow.
        const k = @as(u64, i.k);
        // Number of instructions that follow the current one. A jump offset
        // must be strictly smaller than this to land inside the program.
        const remaining = filter.len - (pc + 1);

        // Do validation/error checking here to compress the second switch.
        switch (i.opcode) {
            LD | ABS | W => if (k + @sizeOf(u32) - 1 >= packet.len) return error.InvalidOffset,
            LD | ABS | H => if (k + @sizeOf(u16) - 1 >= packet.len) return error.InvalidOffset,
            LD | ABS | B => if (k >= packet.len) return error.InvalidOffset,
            LD | IND | W => if (k + x + @sizeOf(u32) - 1 >= packet.len) return error.InvalidOffset,
            LD | IND | H => if (k + x + @sizeOf(u16) - 1 >= packet.len) return error.InvalidOffset,
            LD | IND | B => if (k + x >= packet.len) return error.InvalidOffset,

            LDX | MSH | B => if (k >= packet.len) return error.InvalidOffset,
            ST, STX, LD | MEM, LDX | MEM => if (i.k >= MEMWORDS) return error.InvalidOffset,

            JMP | JA => if (remaining <= i.k) return error.InvalidOffset,
            JMP | JEQ | K,
            JMP | JGT | K,
            JMP | JGE | K,
            JMP | JSET | K,
            JMP | JEQ | X,
            JMP | JGT | X,
            JMP | JGE | X,
            JMP | JSET | X,
            => if (remaining <= i.jt or remaining <= i.jf) return error.InvalidLocation,
            else => {},
        }

        // Everything was validated above, so the accesses below are in bounds.
        switch (i.opcode) {
            LD | IMM => a = i.k,
            LD | MEM => a = m[i.k],
            LD | LEN | W => a = len,
            LD | RND | W => a = random.int(u32),
            LD | ABS | W => a = mem.readInt(u32, packet[i.k..][0..@sizeOf(u32)], byte_order),
            LD | ABS | H => a = mem.readInt(u16, packet[i.k..][0..@sizeOf(u16)], byte_order),
            LD | ABS | B => a = packet[i.k],
            LD | IND | W => a = mem.readInt(u32, packet[i.k + x ..][0..@sizeOf(u32)], byte_order),
            LD | IND | H => a = mem.readInt(u16, packet[i.k + x ..][0..@sizeOf(u16)], byte_order),
            LD | IND | B => a = packet[i.k + x],

            LDX | IMM => x = i.k,
            LDX | MEM => x = m[i.k],
            LDX | LEN | W => x = len,
            // Low nibble of the byte, multiplied by four.
            LDX | MSH | B => x = @as(u32, @truncate(u4, packet[i.k])) << 2,

            ST => m[i.k] = a,
            STX => m[i.k] = x,

            ALU | ADD | K => a +%= i.k,
            ALU | SUB | K => a -%= i.k,
            ALU | MUL | K => a *%= i.k,
            ALU | DIV | K => a = try math.divTrunc(u32, a, i.k),
            ALU | OR | K => a |= i.k,
            ALU | AND | K => a &= i.k,
            ALU | LSH | K => a = math.shl(u32, a, i.k),
            ALU | RSH | K => a = math.shr(u32, a, i.k),
            ALU | MOD | K => a = try math.mod(u32, a, i.k),
            ALU | XOR | K => a ^= i.k,
            ALU | ADD | X => a +%= x,
            ALU | SUB | X => a -%= x,
            ALU | MUL | X => a *%= x,
            ALU | DIV | X => a = try math.divTrunc(u32, a, x),
            ALU | OR | X => a |= x,
            ALU | AND | X => a &= x,
            ALU | LSH | X => a = math.shl(u32, a, x),
            ALU | RSH | X => a = math.shr(u32, a, x),
            ALU | MOD | X => a = try math.mod(u32, a, x),
            ALU | XOR | X => a ^= x,
            // Two's complement negation with wrap-around.
            ALU | NEG => a = @bitCast(u32, -%@bitCast(i32, a)),

            JMP | JA => pc += i.k,
            JMP | JEQ | K => pc += if (a == i.k) i.jt else i.jf,
            JMP | JGT | K => pc += if (a > i.k) i.jt else i.jf,
            JMP | JGE | K => pc += if (a >= i.k) i.jt else i.jf,
            JMP | JSET | K => pc += if (a & i.k > 0) i.jt else i.jf,
            JMP | JEQ | X => pc += if (a == x) i.jt else i.jf,
            JMP | JGT | X => pc += if (a > x) i.jt else i.jf,
            JMP | JGE | X => pc += if (a >= x) i.jt else i.jf,
            JMP | JSET | X => pc += if (a & x > 0) i.jt else i.jf,

            RET | K => return i.k,
            RET | A => return a,

            MISC | TAX => x = a,
            MISC | TXA => a = x,
            else => return error.InvalidOpcode,
        }
    }

    return error.NoReturn;
}

// This program is the BPF form of the tcpdump filter:

//

//     tcpdump -dd 'ip host mirror.internode.on.net and tcp port ftp-data'

//

// As of January 2022, mirror.internode.on.net resolves to 150.101.135.3

//

// For reference, here's what it looks like in BPF assembler.

// Note that the jumps are used for TCP/IP layer checks.

//

// ```

//       ldh [12] (#proto)

//       jeq #0x0800 (ETHERTYPE_IP), L1, fail

// L1:   ld [26]

//       jeq #150.101.135.3, L2, dest

// dest: ld [30]

//       jeq #150.101.135.3, L2, fail

// L2:   ldb [23]

//       jeq #0x6 (IPPROTO_TCP), L3, fail

// L3:   ldh [20]

//       jset #0x1fff, fail, plen

// plen: ldx 4 * ([14] & 0xf)

//       ldh [x + 14]

//       jeq  #0x14 (FTP), pass, dstp

// dstp: ldh [x + 16]

//       jeq  #0x14 (FTP), pass, fail

// pass: ret #0x40000

// fail: ret #0

// ```

// Instruction indices are given in parentheses. A jump offset n lands on the
// instruction n places after the one that follows the jump.
const tcpdump_filter = [_]Insn{
    // (0) ldh [12]
    Insn.ld_abs(.half_word, 12),
    // (1) a == 0x800: continue at (2), otherwise go to (16)
    Insn.jmp(.jeq, .{ .k = 0x800 }, 0, 14),
    // (2) ld [26]
    Insn.ld_abs(.word, 26),
    // (3) a == 150.101.135.3: go to (6), otherwise continue at (4)
    Insn.jmp(.jeq, .{ .k = 0x96658703 }, 2, 0),
    // (4) ld [30]
    Insn.ld_abs(.word, 30),
    // (5) a == 150.101.135.3: continue at (6), otherwise go to (16)
    Insn.jmp(.jeq, .{ .k = 0x96658703 }, 0, 10),
    // (6) ldb [23]
    Insn.ld_abs(.byte, 23),
    // (7) a == 0x6: continue at (8), otherwise go to (16)
    Insn.jmp(.jeq, .{ .k = 0x6 }, 0, 8),
    // (8) ldh [20]
    Insn.ld_abs(.half_word, 20),
    // (9) a & 0x1fff != 0: go to (16), otherwise continue at (10)
    Insn.jmp(.jset, .{ .k = 0x1fff }, 6, 0),
    // (10) ldx 4 * ([14] & 0xf)
    Insn.ldx_msh(14),
    // (11) ldh [x + 14]
    Insn.ld_ind(.half_word, 14),
    // (12) a == 0x14: go to (15), otherwise continue at (13)
    Insn.jmp(.jeq, .{ .k = 0x14 }, 2, 0),
    // (13) ldh [x + 16]
    Insn.ld_ind(.half_word, 16),
    // (14) a == 0x14: continue at (15), otherwise go to (16)
    Insn.jmp(.jeq, .{ .k = 0x14 }, 0, 1),
    // (15) pass: ret #0x40000
    Insn.ret(.{ .k = 0x40000 }),
    // (16) fail: ret #0
    Insn.ret(.{ .k = 0 }),
};

// This packet is the output of `ls` on mirror.internode.on.net:/, captured

// using the filter above.

//

// zig fmt: off

// Sample frame (Ethernet, IPv4 and TCP headers followed by the FTP payload)
// that the "tcpdump filter" test feeds to `simulate`.
const ftp_data = [_]u8{
    // ethernet - 14 bytes: IPv4(0x0800) from a4:71:74:ad:4b:f0 -> de:ad:be:ef:f0:0f

    0xde, 0xad, 0xbe, 0xef, 0xf0, 0x0f, 0xa4, 0x71, 0x74, 0xad, 0x4b, 0xf0, 0x08, 0x00,
    // IPv4 - 20 bytes: TCP data from 150.101.135.3 -> 192.168.1.3

    0x45, 0x00, 0x01, 0xf2, 0x70, 0x3b, 0x40, 0x00, 0x37, 0x06, 0xf2, 0xb6,
    0x96, 0x65, 0x87, 0x03, 0xc0, 0xa8, 0x01, 0x03,
    // TCP - 32 bytes: Source port: 20 (FTP). Payload = 446 bytes

    0x00, 0x14, 0x80, 0x6d, 0x35, 0x81, 0x2d, 0x40, 0x4f, 0x8a, 0x29, 0x9e, 0x80, 0x18, 0x00, 0x2e,
    0x88, 0x8d, 0x00, 0x00, 0x01, 0x01, 0x08, 0x0a, 0x0b, 0x59, 0x5d, 0x09, 0x32, 0x8b, 0x51, 0xa0
} ++
    // Raw line-based FTP data - 446 bytes

    "lrwxrwxrwx   1 root     root           12 Feb 14  2012 debian -> .pub2/debian\r\n" ++
    "lrwxrwxrwx   1 root     root           15 Feb 14  2012 debian-cd -> .pub2/debian-cd\r\n" ++
    "lrwxrwxrwx   1 root     root            9 Mar  9  2018 linux -> pub/linux\r\n" ++
    "drwxr-xr-X   3 mirror   mirror       4096 Sep 20 08:10 pub\r\n" ++
    "lrwxrwxrwx   1 root     root           12 Feb 14  2012 ubuntu -> .pub2/ubuntu\r\n" ++
    "-rw-r--r--   1 root     root         1044 Jan 20  2015 welcome.msg\r\n";
// zig fmt: on


test "tcpdump filter" {
    // The captured frame must be accepted with the filter's pass verdict.
    const verdict = try simulate(ftp_data, &tcpdump_filter, .Big);
    try expectEqual(@as(u32, 0x40000), verdict);
}

/// Test helper: runs `filter` over the bytes of `data` with big-endian packet
/// loads and fails unless the program returns the verdict 0.
fn expectPass(data: anytype, filter: []const Insn) !void {
    const verdict = try simulate(mem.asBytes(data), filter, .Big);
    return expectEqual(@as(u32, 0), verdict);
}

/// Test helper: runs `filter` over the bytes of `data` using the native byte
/// order and fails unless the simulator returns `expected_error`.
fn expectFail(expected_error: anyerror, data: anytype, filter: []const Insn) !void {
    const outcome = simulate(mem.asBytes(data), filter, native_endian);
    return expectError(expected_error, outcome);
}

// Exercises every instruction the simulator implements in one long program.
// Each check is followed by a `ret` carrying a unique non-zero verdict that is
// skipped when the check succeeds, so the value returned by a failing run
// identifies the exact check that went wrong. The program ends with `ret #0`,
// which is what `expectPass` looks for.
test "simulator coverage" {
    const some_data = [_]u8{
        0xaa, 0xbb, 0xcc, 0xdd, 0x7f,
    };

    try expectPass(&some_data, &.{
        // ld  #10
        // ldx #1
        // st M[0]
        // stx M[1]
        // fail if a != 10
        Insn.ld_imm(10),
        Insn.ldx_imm(1),
        Insn.st(.m0),
        Insn.stx(.m1),
        Insn.jmp(.jeq, .{ .k = 10 }, 1, 0),
        Insn.ret(.{ .k = 1 }),
        // ld [0]
        // fail if a != 0xaabbccdd
        Insn.ld_abs(.word, 0),
        Insn.jmp(.jeq, .{ .k = 0xaabbccdd }, 1, 0),
        Insn.ret(.{ .k = 2 }),
        // ldh [0]
        // fail if a != 0xaabb
        Insn.ld_abs(.half_word, 0),
        Insn.jmp(.jeq, .{ .k = 0xaabb }, 1, 0),
        Insn.ret(.{ .k = 3 }),
        // ldb [0]
        // fail if a != 0xaa
        Insn.ld_abs(.byte, 0),
        Insn.jmp(.jeq, .{ .k = 0xaa }, 1, 0),
        Insn.ret(.{ .k = 4 }),
        // ld [x + 0]
        // fail if a != 0xbbccdd7f
        Insn.ld_ind(.word, 0),
        Insn.jmp(.jeq, .{ .k = 0xbbccdd7f }, 1, 0),
        Insn.ret(.{ .k = 5 }),
        // ldh [x + 0]
        // fail if a != 0xbbcc
        Insn.ld_ind(.half_word, 0),
        Insn.jmp(.jeq, .{ .k = 0xbbcc }, 1, 0),
        Insn.ret(.{ .k = 6 }),
        // ldb [x + 0]
        // fail if a != 0xbb
        Insn.ld_ind(.byte, 0),
        Insn.jmp(.jeq, .{ .k = 0xbb }, 1, 0),
        Insn.ret(.{ .k = 7 }),
        // ld M[0]
        // fail if a != 10
        Insn.ld_mem(.m0),
        Insn.jmp(.jeq, .{ .k = 10 }, 1, 0),
        Insn.ret(.{ .k = 8 }),
        // ld #len
        // fail if a != 5
        Insn.ld_len(),
        Insn.jmp(.jeq, .{ .k = some_data.len }, 1, 0),
        Insn.ret(.{ .k = 9 }),
        // ld #0
        // ld arc4random()
        // fail if a == 0
        Insn.ld_imm(0),
        Insn.ld_rnd(),
        Insn.jmp(.jgt, .{ .k = 0 }, 1, 0),
        Insn.ret(.{ .k = 10 }),
        // ld  #3
        // ldx #10
        // st M[2]
        // txa
        // fail if a != x
        Insn.ld_imm(3),
        Insn.ldx_imm(10),
        Insn.st(.m2),
        Insn.txa(),
        Insn.jmp(.jeq, .x, 1, 0),
        Insn.ret(.{ .k = 11 }),
        // ldx M[2]
        // fail if a <= x
        Insn.ldx_mem(.m2),
        Insn.jmp(.jgt, .x, 1, 0),
        Insn.ret(.{ .k = 12 }),
        // ldx #len
        // fail if a <= x
        Insn.ldx_len(),
        Insn.jmp(.jgt, .x, 1, 0),
        Insn.ret(.{ .k = 13 }),
        // a = 4 * (0x7f & 0xf)
        // x = 4 * ([4]  & 0xf)
        // fail if a != x
        Insn.ld_imm(4 * (0x7f & 0xf)),
        Insn.ldx_msh(4),
        Insn.jmp(.jeq, .x, 1, 0),
        Insn.ret(.{ .k = 14 }),
        // ld  #(u32)-1
        // ldx #2
        // add #1
        // fail if a != 0
        Insn.ld_imm(0xffffffff),
        Insn.ldx_imm(2),
        Insn.alu(.add, .{ .k = 1 }),
        Insn.jmp(.jeq, .{ .k = 0 }, 1, 0),
        Insn.ret(.{ .k = 15 }),
        // sub #1
        // fail if a != (u32)-1
        Insn.alu(.sub, .{ .k = 1 }),
        Insn.jmp(.jeq, .{ .k = 0xffffffff }, 1, 0),
        Insn.ret(.{ .k = 16 }),
        // add x
        // fail if a != 1
        Insn.alu(.add, .x),
        Insn.jmp(.jeq, .{ .k = 1 }, 1, 0),
        Insn.ret(.{ .k = 17 }),
        // sub x
        // fail if a != (u32)-1
        Insn.alu(.sub, .x),
        Insn.jmp(.jeq, .{ .k = 0xffffffff }, 1, 0),
        Insn.ret(.{ .k = 18 }),
        // ld #16
        // mul #2
        // fail if a != 32
        Insn.ld_imm(16),
        Insn.alu(.mul, .{ .k = 2 }),
        Insn.jmp(.jeq, .{ .k = 32 }, 1, 0),
        Insn.ret(.{ .k = 19 }),
        // mul x
        // fail if a != 64
        Insn.alu(.mul, .x),
        Insn.jmp(.jeq, .{ .k = 64 }, 1, 0),
        Insn.ret(.{ .k = 20 }),
        // div #2
        // fail if a != 32
        Insn.alu(.div, .{ .k = 2 }),
        Insn.jmp(.jeq, .{ .k = 32 }, 1, 0),
        Insn.ret(.{ .k = 21 }),
        // div x
        // fail if a != 16
        Insn.alu(.div, .x),
        Insn.jmp(.jeq, .{ .k = 16 }, 1, 0),
        Insn.ret(.{ .k = 22 }),
        // or #4
        // fail if a != 20
        Insn.alu(.@"or", .{ .k = 4 }),
        Insn.jmp(.jeq, .{ .k = 20 }, 1, 0),
        Insn.ret(.{ .k = 23 }),
        // or x
        // fail if a != 22
        Insn.alu(.@"or", .x),
        Insn.jmp(.jeq, .{ .k = 22 }, 1, 0),
        Insn.ret(.{ .k = 24 }),
        // and #6
        // fail if a != 6
        Insn.alu(.@"and", .{ .k = 0b110 }),
        Insn.jmp(.jeq, .{ .k = 6 }, 1, 0),
        Insn.ret(.{ .k = 25 }),
        // and x
        // fail if a != x (both are 2)
        Insn.alu(.@"and", .x),
        Insn.jmp(.jeq, .x, 1, 0),
        Insn.ret(.{ .k = 26 }),
        // xor #15
        // fail if a != 13
        Insn.alu(.xor, .{ .k = 0b1111 }),
        Insn.jmp(.jeq, .{ .k = 0b1101 }, 1, 0),
        Insn.ret(.{ .k = 27 }),
        // xor x
        // fail if a != 15
        Insn.alu(.xor, .x),
        Insn.jmp(.jeq, .{ .k = 0b1111 }, 1, 0),
        Insn.ret(.{ .k = 28 }),
        // rsh #1
        // fail if a != 7
        Insn.alu(.rsh, .{ .k = 1 }),
        Insn.jmp(.jeq, .{ .k = 0b0111 }, 1, 0),
        Insn.ret(.{ .k = 29 }),
        // rsh x
        // fail if a != 1
        Insn.alu(.rsh, .x),
        Insn.jmp(.jeq, .{ .k = 0b0001 }, 1, 0),
        Insn.ret(.{ .k = 30 }),
        // lsh #1
        // fail if a != 2
        Insn.alu(.lsh, .{ .k = 1 }),
        Insn.jmp(.jeq, .{ .k = 0b0010 }, 1, 0),
        Insn.ret(.{ .k = 31 }),
        // lsh x
        // fail if a != 8
        Insn.alu(.lsh, .x),
        Insn.jmp(.jeq, .{ .k = 0b1000 }, 1, 0),
        Insn.ret(.{ .k = 32 }),
        // mod #6
        // fail if a != 2
        Insn.alu(.mod, .{ .k = 6 }),
        Insn.jmp(.jeq, .{ .k = 2 }, 1, 0),
        Insn.ret(.{ .k = 33 }),
        // mod x
        // fail if a != 0
        Insn.alu(.mod, .x),
        Insn.jmp(.jeq, .{ .k = 0 }, 1, 0),
        Insn.ret(.{ .k = 34 }),
        // txa
        // neg
        // fail if a != (u32)-2
        Insn.txa(),
        Insn.alu_neg(),
        Insn.jmp(.jeq, .{ .k = ~@as(u32, 2) + 1 }, 1, 0),
        Insn.ret(.{ .k = 35 }),
        // ja #1 (skip the next instruction)
        Insn.jmp_ja(1),
        Insn.ret(.{ .k = 36 }),
        // ld #20
        // tax
        // fail if a != 20
        // fail if a != x
        Insn.ld_imm(20),
        Insn.tax(),
        Insn.jmp(.jeq, .{ .k = 20 }, 1, 0),
        Insn.ret(.{ .k = 37 }),
        Insn.jmp(.jeq, .x, 1, 0),
        Insn.ret(.{ .k = 38 }),
        // ld #19
        // fail if a == 20
        // fail if a == x
        // fail if a > 20
        // fail if a > x
        Insn.ld_imm(19),
        Insn.jmp(.jeq, .{ .k = 20 }, 0, 1),
        Insn.ret(.{ .k = 39 }),
        Insn.jmp(.jeq, .x, 0, 1),
        Insn.ret(.{ .k = 40 }),
        Insn.jmp(.jgt, .{ .k = 20 }, 0, 1),
        Insn.ret(.{ .k = 41 }),
        Insn.jmp(.jgt, .x, 0, 1),
        Insn.ret(.{ .k = 42 }),
        // ld #21
        // fail if a <= 20
        // fail if a <= x
        Insn.ld_imm(21),
        Insn.jmp(.jgt, .{ .k = 20 }, 1, 0),
        Insn.ret(.{ .k = 43 }),
        Insn.jmp(.jgt, .x, 1, 0),
        Insn.ret(.{ .k = 44 }),
        // ldx #22
        // fail if a >= 22
        // fail if a >= x
        Insn.ldx_imm(22),
        Insn.jmp(.jge, .{ .k = 22 }, 0, 1),
        Insn.ret(.{ .k = 45 }),
        Insn.jmp(.jge, .x, 0, 1),
        Insn.ret(.{ .k = 46 }),
        // ld #23
        // fail if a < 22
        // fail if a < x
        Insn.ld_imm(23),
        Insn.jmp(.jge, .{ .k = 22 }, 1, 0),
        Insn.ret(.{ .k = 47 }),
        Insn.jmp(.jge, .x, 1, 0),
        Insn.ret(.{ .k = 48 }),
        // ldx #0b10100
        // fail if a & 0b10100 == 0
        // fail if a & x       == 0
        Insn.ldx_imm(0b10100),
        Insn.jmp(.jset, .{ .k = 0b10100 }, 1, 0),
        Insn.ret(.{ .k = 49 }),
        Insn.jmp(.jset, .x, 1, 0),
        Insn.ret(.{ .k = 50 }),
        // ldx #0
        // fail if a & 0 > 0
        // fail if a & x > 0
        Insn.ldx_imm(0),
        Insn.jmp(.jset, .{ .k = 0 }, 0, 1),
        Insn.ret(.{ .k = 51 }),
        Insn.jmp(.jset, .x, 0, 1),
        Insn.ret(.{ .k = 52 }),
        Insn.ret(.{ .k = 0 }),
    });
    // ret a: the verdict is whatever was loaded into a last.
    try expectPass(&some_data, &.{
        Insn.ld_imm(35),
        Insn.ld_imm(0),
        Insn.ret(.a),
    });

    // Errors

    // The program ends without a ret.
    try expectFail(error.NoReturn, &some_data, &.{
        Insn.ld_imm(10),
    });
    // 0x7f is not an opcode the simulator implements.
    try expectFail(error.InvalidOpcode, &some_data, &.{
        Insn.stmt(0x7f, 0xdeadbeef),
    });
    // Offset 10 is past the end of the 5 byte packet.
    try expectFail(error.InvalidOffset, &some_data, &.{
        Insn.stmt(LD | ABS | W, 10),
    });
    // jt points past the end of the program.
    try expectFail(error.InvalidLocation, &some_data, &.{
        Insn.jmp(.jeq, .{ .k = 0 }, 10, 0),
    });
    // jf points past the end of the program.
    try expectFail(error.InvalidLocation, &some_data, &.{
        Insn.jmp(.jeq, .{ .k = 0 }, 0, 10),
    });
}