const std = @import("../../std.zig");
const builtin = @import("builtin");

const io = std.io;
const os = std.os;
const ip = std.x.net.ip;

const fmt = std.fmt;
const mem = std.mem;
const testing = std.testing;
const native_os = builtin.os;

const IPv4 = std.x.os.IPv4;
const IPv6 = std.x.os.IPv6;
const Socket = std.x.os.Socket;
const Buffer = std.x.os.Buffer;

/// A generic TCP socket abstraction.
const tcp = @This();

/// A TCP client-address pair.
pub const Connection = struct {
    client: tcp.Client,
    address: ip.Address,

    /// Enclose a TCP client and address into a client-address pair.
    pub fn from(conn: Socket.Connection) tcp.Connection {
        return .{
            .client = tcp.Client.from(conn.socket),
            .address = ip.Address.from(conn.address),
        };
    }

    /// Unravel a TCP client-address pair into a socket-address pair.
    pub fn into(self: tcp.Connection) Socket.Connection {
        return .{
            .socket = self.client.socket,
            .address = self.address.into(),
        };
    }

    /// Closes the underlying client of the connection.
    pub fn deinit(self: tcp.Connection) void {
        self.client.deinit();
    }
};

/// Possible domains that a TCP client/listener may operate over.
pub const Domain = enum(u16) {
    ip = os.AF.INET,
    ipv6 = os.AF.INET6,
};

/// A TCP client.
pub const Client = struct {
    socket: Socket,

    /// Implements `std.io.Reader`.
    pub const Reader = struct {
        client: Client,
        flags: u32,

        /// Implements `readFn` for `std.io.Reader`.
        pub fn read(self: Client.Reader, buffer: []u8) !usize {
            return self.client.read(buffer, self.flags);
        }
    };

    /// Implements `std.io.Writer`.
    pub const Writer = struct {
        client: Client,
        flags: u32,

        /// Implements `writeFn` for `std.io.Writer`.
        pub fn write(self: Client.Writer, buffer: []const u8) !usize {
            return self.client.write(buffer, self.flags);
        }
    };

    /// Opens a new client.
    pub fn init(domain: tcp.Domain, flags: std.enums.EnumFieldStruct(Socket.InitFlags, bool, false)) !Client {
        return Client{
            .socket = try Socket.init(
                @enumToInt(domain),
                os.SOCK.STREAM,
                os.IPPROTO.TCP,
                flags,
            ),
        };
    }

    /// Enclose a TCP client over an existing socket.
    pub fn from(socket: Socket) Client {
        return Client{ .socket = socket };
    }

    /// Closes the client.
    pub fn deinit(self: Client) void {
        self.socket.deinit();
    }

    /// Shutdown either the read side, write side, or all sides of the client's underlying socket.
    pub fn shutdown(self: Client, how: os.ShutdownHow) !void {
        return self.socket.shutdown(how);
    }

    /// Have the client attempt to the connect to an address.
    pub fn connect(self: Client, address: ip.Address) !void {
        return self.socket.connect(address.into());
    }

    /// Extracts the error set of a function.
    /// TODO: remove after Socket.{read, write} error unions are well-defined across different platforms
    fn ErrorSetOf(comptime Function: anytype) type {
        return @typeInfo(@typeInfo(@TypeOf(Function)).Fn.return_type.?).ErrorUnion.error_set;
    }

    /// Wrap `tcp.Client` into `std.io.Reader`.
    pub fn reader(self: Client, flags: u32) io.Reader(Client.Reader, ErrorSetOf(Client.Reader.read), Client.Reader.read) {
        return .{ .context = .{ .client = self, .flags = flags } };
    }

    /// Wrap `tcp.Client` into `std.io.Writer`.
    pub fn writer(self: Client, flags: u32) io.Writer(Client.Writer, ErrorSetOf(Client.Writer.write), Client.Writer.write) {
        return .{ .context = .{ .client = self, .flags = flags } };
    }

    /// Read data from the socket into the buffer provided with a set of flags
    /// specified. It returns the number of bytes read into the buffer provided.
    pub fn read(self: Client, buf: []u8, flags: u32) !usize {
        return self.socket.read(buf, flags);
    }

    /// Write a buffer of data provided to the socket with a set of flags specified.
    /// It returns the number of bytes that are written to the socket.
    pub fn write(self: Client, buf: []const u8, flags: u32) !usize {
        return self.socket.write(buf, flags);
    }

    /// Writes multiple I/O vectors with a prepended message header to the socket
    /// with a set of flags specified. It returns the number of bytes that are
    /// written to the socket.
    pub fn writeMessage(self: Client, msg: Socket.Message, flags: u32) !usize {
        return self.socket.writeMessage(msg, flags);
    }

    /// Read multiple I/O vectors with a prepended message header from the socket
    /// with a set of flags specified. It returns the number of bytes that were
    /// read into the buffer provided.
    pub fn readMessage(self: Client, msg: *Socket.Message, flags: u32) !usize {
        return self.socket.readMessage(msg, flags);
    }

    /// Query and return the latest cached error on the client's underlying socket.
    pub fn getError(self: Client) !void {
        return self.socket.getError();
    }

    /// Query the read buffer size of the client's underlying socket.
    pub fn getReadBufferSize(self: Client) !u32 {
        return self.socket.getReadBufferSize();
    }

    /// Query the write buffer size of the client's underlying socket.
    pub fn getWriteBufferSize(self: Client) !u32 {
        return self.socket.getWriteBufferSize();
    }

    /// Query the address that the client's socket is locally bounded to.
    pub fn getLocalAddress(self: Client) !ip.Address {
        return ip.Address.from(try self.socket.getLocalAddress());
    }

    /// Query the address that the socket is connected to.
    pub fn getRemoteAddress(self: Client) !ip.Address {
        return ip.Address.from(try self.socket.getRemoteAddress());
    }

    /// Have close() or shutdown() syscalls block until all queued messages in the client have been successfully
    /// sent, or if the timeout specified in seconds has been reached. It returns `error.UnsupportedSocketOption`
    /// if the host does not support the option for a socket to linger around up until a timeout specified in
    /// seconds.
    pub fn setLinger(self: Client, timeout_seconds: ?u16) !void {
        return self.socket.setLinger(timeout_seconds);
    }

    /// Have keep-alive messages be sent periodically. The timing in which keep-alive messages are sent are
    /// dependant on operating system settings. It returns `error.UnsupportedSocketOption` if the host does
    /// not support periodically sending keep-alive messages on connection-oriented sockets.
    pub fn setKeepAlive(self: Client, enabled: bool) !void {
        return self.socket.setKeepAlive(enabled);
    }

    /// Disable Nagle's algorithm on a TCP socket. It returns `error.UnsupportedSocketOption` if
    /// the host does not support sockets disabling Nagle's algorithm.
    pub fn setNoDelay(self: Client, enabled: bool) !void {
        if (@hasDecl(os.TCP, "NODELAY")) {
            const bytes = mem.asBytes(&@as(usize, @boolToInt(enabled)));
            return self.socket.setOption(os.IPPROTO.TCP, os.TCP.NODELAY, bytes);
        }
        return error.UnsupportedSocketOption;
    }

    /// Enables TCP Quick ACK on a TCP socket to immediately send rather than delay ACKs when necessary. It returns
    /// `error.UnsupportedSocketOption` if the host does not support TCP Quick ACK.
    pub fn setQuickACK(self: Client, enabled: bool) !void {
        if (@hasDecl(os.TCP, "QUICKACK")) {
            return self.socket.setOption(os.IPPROTO.TCP, os.TCP.QUICKACK, mem.asBytes(&@as(u32, @boolToInt(enabled))));
        }
        return error.UnsupportedSocketOption;
    }

    /// Set the write buffer size of the socket.
    pub fn setWriteBufferSize(self: Client, size: u32) !void {
        return self.socket.setWriteBufferSize(size);
    }

    /// Set the read buffer size of the socket.
    pub fn setReadBufferSize(self: Client, size: u32) !void {
        return self.socket.setReadBufferSize(size);
    }

    /// Set a timeout on the socket that is to occur if no messages are successfully written
    /// to its bound destination after a specified number of milliseconds. A subsequent write
    /// to the socket will thereafter return `error.WouldBlock` should the timeout be exceeded.
    pub fn setWriteTimeout(self: Client, milliseconds: u32) !void {
        return self.socket.setWriteTimeout(milliseconds);
    }

    /// Set a timeout on the socket that is to occur if no messages are successfully read
    /// from its bound destination after a specified number of milliseconds. A subsequent
    /// read from the socket will thereafter return `error.WouldBlock` should the timeout be
    /// exceeded.
    pub fn setReadTimeout(self: Client, milliseconds: u32) !void {
        return self.socket.setReadTimeout(milliseconds);
    }
};

/// A TCP listener.
pub const Listener = struct {
    socket: Socket,

    /// Opens a new listener.
    pub fn init(domain: tcp.Domain, flags: std.enums.EnumFieldStruct(Socket.InitFlags, bool, false)) !Listener {
        return Listener{
            .socket = try Socket.init(
                @enumToInt(domain),
                os.SOCK.STREAM,
                os.IPPROTO.TCP,
                flags,
            ),
        };
    }

    /// Closes the listener.
    pub fn deinit(self: Listener) void {
        self.socket.deinit();
    }

    /// Shuts down the underlying listener's socket. The next subsequent call, or
    /// a current pending call to accept() after shutdown is called will return
    /// an error.
    pub fn shutdown(self: Listener) !void {
        return self.socket.shutdown(.recv);
    }

    /// Binds the listener's socket to an address.
    pub fn bind(self: Listener, address: ip.Address) !void {
        return self.socket.bind(address.into());
    }

    /// Start listening for incoming connections.
    pub fn listen(self: Listener, max_backlog_size: u31) !void {
        return self.socket.listen(max_backlog_size);
    }

    /// Accept a pending incoming connection queued to the kernel backlog
    /// of the listener's socket.
    pub fn accept(self: Listener, flags: std.enums.EnumFieldStruct(Socket.InitFlags, bool, false)) !tcp.Connection {
        return tcp.Connection.from(try self.socket.accept(flags));
    }

    /// Query and return the latest cached error on the listener's underlying socket.
    pub fn getError(self: Client) !void {
        return self.socket.getError();
    }

    /// Query the address that the listener's socket is locally bounded to.
    pub fn getLocalAddress(self: Listener) !ip.Address {
        return ip.Address.from(try self.socket.getLocalAddress());
    }

    /// Allow multiple sockets on the same host to listen on the same address. It returns `error.UnsupportedSocketOption` if
    /// the host does not support sockets listening the same address.
    pub fn setReuseAddress(self: Listener, enabled: bool) !void {
        return self.socket.setReuseAddress(enabled);
    }

    /// Allow multiple sockets on the same host to listen on the same port. It returns `error.UnsupportedSocketOption` if
    /// the host does not supports sockets listening on the same port.
    pub fn setReusePort(self: Listener, enabled: bool) !void {
        return self.socket.setReusePort(enabled);
    }

    /// Enables TCP Fast Open (RFC 7413) on a TCP socket. It returns `error.UnsupportedSocketOption` if the host does not
    /// support TCP Fast Open.
    pub fn setFastOpen(self: Listener, enabled: bool) !void {
        if (@hasDecl(os.TCP, "FASTOPEN")) {
            return self.socket.setOption(os.IPPROTO.TCP, os.TCP.FASTOPEN, mem.asBytes(&@as(u32, @boolToInt(enabled))));
        }
        return error.UnsupportedSocketOption;
    }

    /// Set a timeout on the listener that is to occur if no new incoming connections come in
    /// after a specified number of milliseconds. A subsequent accept call to the listener
    /// will thereafter return `error.WouldBlock` should the timeout be exceeded.
    pub fn setAcceptTimeout(self: Listener, milliseconds: usize) !void {
        return self.socket.setReadTimeout(milliseconds);
    }
};

test "tcp: create client/listener pair" {
    if (native_os.tag == .wasi) return error.SkipZigTest;

    const listener = try tcp.Listener.init(.ip, .{ .close_on_exec = true });
    defer listener.deinit();

    try listener.bind(ip.Address.initIPv4(IPv4.unspecified, 0));
    try listener.listen(128);

    var binded_address = try listener.getLocalAddress();
    switch (binded_address) {
        .ipv4 => |*ipv4| ipv4.host = IPv4.localhost,
        .ipv6 => |*ipv6| ipv6.host = IPv6.localhost,
    }

    const client = try tcp.Client.init(.ip, .{ .close_on_exec = true });
    defer client.deinit();

    try client.connect(binded_address);

    const conn = try listener.accept(.{ .close_on_exec = true });
    defer conn.deinit();
}

test "tcp/client: 1ms read timeout" {
    if (native_os.tag == .wasi) return error.SkipZigTest;

    const listener = try tcp.Listener.init(.ip, .{ .close_on_exec = true });
    defer listener.deinit();

    try listener.bind(ip.Address.initIPv4(IPv4.unspecified, 0));
    try listener.listen(128);

    var binded_address = try listener.getLocalAddress();
    switch (binded_address) {
        .ipv4 => |*ipv4| ipv4.host = IPv4.localhost,
        .ipv6 => |*ipv6| ipv6.host = IPv6.localhost,
    }

    const client = try tcp.Client.init(.ip, .{ .close_on_exec = true });
    defer client.deinit();

    try client.connect(binded_address);
    try client.setReadTimeout(1);

    const conn = try listener.accept(.{ .close_on_exec = true });
    defer conn.deinit();

    var buf: [1]u8 = undefined;
    try testing.expectError(error.WouldBlock, client.reader(0).read(&buf));
}

test "tcp/client: read and write multiple vectors" {
    if (native_os.tag == .wasi) return error.SkipZigTest;

    const listener = try tcp.Listener.init(.ip, .{ .close_on_exec = true });
    defer listener.deinit();

    try listener.bind(ip.Address.initIPv4(IPv4.unspecified, 0));
    try listener.listen(128);

    var binded_address = try listener.getLocalAddress();
    switch (binded_address) {
        .ipv4 => |*ipv4| ipv4.host = IPv4.localhost,
        .ipv6 => |*ipv6| ipv6.host = IPv6.localhost,
    }

    const client = try tcp.Client.init(.ip, .{ .close_on_exec = true });
    defer client.deinit();

    try client.connect(binded_address);

    const conn = try listener.accept(.{ .close_on_exec = true });
    defer conn.deinit();

    const message = "hello world";
    _ = try conn.client.writeMessage(Socket.Message.fromBuffers(&[_]Buffer{
        Buffer.from(message[0 .. message.len / 2]),
        Buffer.from(message[message.len / 2 ..]),
    }), 0);

    var buf: [message.len + 1]u8 = undefined;
    var msg = Socket.Message.fromBuffers(&[_]Buffer{
        Buffer.from(buf[0 .. message.len / 2]),
        Buffer.from(buf[message.len / 2 ..]),
    });
    _ = try client.readMessage(&msg, 0);

    try testing.expectEqualStrings(message, buf[0..message.len]);
}

test "tcp/listener: bind to unspecified ipv4 address" {
    if (native_os.tag == .wasi) return error.SkipZigTest;

    const listener = try tcp.Listener.init(.ip, .{ .close_on_exec = true });
    defer listener.deinit();

    try listener.bind(ip.Address.initIPv4(IPv4.unspecified, 0));
    try listener.listen(128);

    const address = try listener.getLocalAddress();
    try testing.expect(address == .ipv4);
}

test "tcp/listener: bind to unspecified ipv6 address" {
    if (native_os.tag == .wasi) return error.SkipZigTest;

    const listener = try tcp.Listener.init(.ipv6, .{ .close_on_exec = true });
    defer listener.deinit();

    try listener.bind(ip.Address.initIPv6(IPv6.unspecified, 0));
    try listener.listen(128);

    const address = try listener.getLocalAddress();
    try testing.expect(address == .ipv6);
}