const std = @import("../../std.zig");
const builtin = @import("builtin");
const io = std.io;
const os = std.os;
const ip = std.x.net.ip;
const fmt = std.fmt;
const mem = std.mem;
const testing = std.testing;
const native_os = builtin.os;
const IPv4 = std.x.os.IPv4;
const IPv6 = std.x.os.IPv6;
const Socket = std.x.os.Socket;
const Buffer = std.x.os.Buffer;
const tcp = @This();
/// A TCP client bundled together with an IP address, mirroring the
/// socket/address pair carried by `Socket.Connection`.
pub const Connection = struct {
    /// TCP wrapper around the connection's socket.
    client: tcp.Client,
    /// IP address that accompanied the socket in the originating `Socket.Connection`.
    address: ip.Address,

    /// Builds a `tcp.Connection` out of a generic `Socket.Connection` by
    /// wrapping its socket in a `tcp.Client` and converting its address
    /// with `ip.Address.from`.
    pub fn from(conn: Socket.Connection) tcp.Connection {
        const wrapped_client = tcp.Client.from(conn.socket);
        const wrapped_address = ip.Address.from(conn.address);
        return tcp.Connection{
            .client = wrapped_client,
            .address = wrapped_address,
        };
    }

    /// Converts this connection back into a generic `Socket.Connection`,
    /// exposing the client's underlying socket and the address converted
    /// through `ip.Address.into`.
    pub fn into(self: tcp.Connection) Socket.Connection {
        const raw_socket = self.client.socket;
        return Socket.Connection{
            .socket = raw_socket,
            .address = self.address.into(),
        };
    }

    /// Releases the connection by deinitializing its client.
    pub fn deinit(self: tcp.Connection) void {
        tcp.Client.deinit(self.client);
    }
};
/// Address families a TCP client or listener can be opened over. Each tag's
/// integer value is the matching `os.AF` constant, so `@enumToInt` yields a
/// value that can be handed directly to `Socket.init`.
pub const Domain = enum(u16) {
    /// IPv4 (`os.AF.INET`).
    ip = os.AF.INET,
    /// IPv6 (`os.AF.INET6`).
    ipv6 = os.AF.INET6,
};
/// A TCP client: a thin wrapper around a `Socket` that forwards nearly every
/// operation to the socket and converts addresses to and from `ip.Address`.
pub const Client = struct {
    /// The underlying socket every operation is forwarded to.
    socket: Socket,

    /// Context type backing the `io.Reader` returned by `Client.reader`.
    /// It remembers the client and the flags to pass on every read.
    pub const Reader = struct {
        client: Client,
        flags: u32,

        /// Reads into `buffer` through `Client.read` using the stored flags.
        pub fn read(self: Client.Reader, buffer: []u8) !usize {
            return self.client.read(buffer, self.flags);
        }
    };

    /// Context type backing the `io.Writer` returned by `Client.writer`.
    /// It remembers the client and the flags to pass on every write.
    pub const Writer = struct {
        client: Client,
        flags: u32,

        /// Writes `buffer` through `Client.write` using the stored flags.
        pub fn write(self: Client.Writer, buffer: []const u8) !usize {
            return self.client.write(buffer, self.flags);
        }
    };

    /// Opens a new client by creating a stream socket (`os.SOCK.STREAM`)
    /// using the TCP protocol (`os.IPPROTO.TCP`) over `domain`. `flags` is
    /// forwarded unchanged to `Socket.init`.
    pub fn init(domain: tcp.Domain, flags: std.enums.EnumFieldStruct(Socket.InitFlags, bool, false)) !Client {
        return Client{
            .socket = try Socket.init(
                @enumToInt(domain),
                os.SOCK.STREAM,
                os.IPPROTO.TCP,
                flags,
            ),
        };
    }

    /// Wraps an already-existing socket in a `Client`.
    pub fn from(socket: Socket) Client {
        return Client{ .socket = socket };
    }

    /// Releases the client by deinitializing its socket.
    pub fn deinit(self: Client) void {
        self.socket.deinit();
    }

    /// Shuts down the side(s) of the underlying socket selected by `how`.
    pub fn shutdown(self: Client, how: os.ShutdownHow) !void {
        return self.socket.shutdown(how);
    }

    /// Connects the underlying socket to `address`.
    pub fn connect(self: Client, address: ip.Address) !void {
        return self.socket.connect(address.into());
    }

    /// Extracts the error set from the return type of `Function`. The
    /// function's return type must be an error union. Used to name the
    /// inferred error sets of `Reader.read` / `Writer.write` when
    /// instantiating `io.Reader` / `io.Writer`.
    fn ErrorSetOf(comptime Function: anytype) type {
        return @typeInfo(@typeInfo(@TypeOf(Function)).Fn.return_type.?).ErrorUnion.error_set;
    }

    /// Returns an `io.Reader` over this client; `flags` is passed along with
    /// every read performed through it.
    pub fn reader(self: Client, flags: u32) io.Reader(Client.Reader, ErrorSetOf(Client.Reader.read), Client.Reader.read) {
        return .{ .context = .{ .client = self, .flags = flags } };
    }

    /// Returns an `io.Writer` over this client; `flags` is passed along with
    /// every write performed through it.
    pub fn writer(self: Client, flags: u32) io.Writer(Client.Writer, ErrorSetOf(Client.Writer.write), Client.Writer.write) {
        return .{ .context = .{ .client = self, .flags = flags } };
    }

    /// Reads from the socket into `buf` with the given `flags`, returning
    /// whatever `Socket.read` returns (presumably the number of bytes read).
    pub fn read(self: Client, buf: []u8, flags: u32) !usize {
        return self.socket.read(buf, flags);
    }

    /// Writes `buf` to the socket with the given `flags`, returning whatever
    /// `Socket.write` returns (presumably the number of bytes written).
    pub fn write(self: Client, buf: []const u8, flags: u32) !usize {
        return self.socket.write(buf, flags);
    }

    /// Writes the multi-buffer message `msg` to the socket with the given
    /// `flags`; forwards to `Socket.writeMessage`.
    pub fn writeMessage(self: Client, msg: Socket.Message, flags: u32) !usize {
        return self.socket.writeMessage(msg, flags);
    }

    /// Reads a multi-buffer message from the socket into `msg` with the
    /// given `flags`; forwards to `Socket.readMessage`.
    pub fn readMessage(self: Client, msg: *Socket.Message, flags: u32) !usize {
        return self.socket.readMessage(msg, flags);
    }

    /// Queries the underlying socket for a pending error; forwards to
    /// `Socket.getError`.
    pub fn getError(self: Client) !void {
        return self.socket.getError();
    }

    /// Returns the read buffer size reported by the underlying socket.
    pub fn getReadBufferSize(self: Client) !u32 {
        return self.socket.getReadBufferSize();
    }

    /// Returns the write buffer size reported by the underlying socket.
    pub fn getWriteBufferSize(self: Client) !u32 {
        return self.socket.getWriteBufferSize();
    }

    /// Returns the local address of the underlying socket as an `ip.Address`.
    pub fn getLocalAddress(self: Client) !ip.Address {
        return ip.Address.from(try self.socket.getLocalAddress());
    }

    /// Returns the remote address of the underlying socket as an `ip.Address`.
    pub fn getRemoteAddress(self: Client) !ip.Address {
        return ip.Address.from(try self.socket.getRemoteAddress());
    }

    /// Forwards `timeout_seconds` to `Socket.setLinger`.
    /// NOTE(review): presumably `null` disables lingering; confirm against
    /// `Socket.setLinger`.
    pub fn setLinger(self: Client, timeout_seconds: ?u16) !void {
        return self.socket.setLinger(timeout_seconds);
    }

    /// Enables or disables keep-alive on the underlying socket.
    pub fn setKeepAlive(self: Client, enabled: bool) !void {
        return self.socket.setKeepAlive(enabled);
    }

    /// Sets the `TCP.NODELAY` option (which disables Nagle's algorithm when
    /// enabled) on the underlying socket. Returns
    /// `error.UnsupportedSocketOption` when the target's `os.TCP` namespace
    /// does not declare `NODELAY`.
    pub fn setNoDelay(self: Client, enabled: bool) !void {
        if (@hasDecl(os.TCP, "NODELAY")) {
            // The option value is a C `int`, so encode it as a 32-bit integer
            // just like `setQuickACK` does. A `usize` value would place the
            // flag in the upper half of the bytes the kernel reads on 64-bit
            // big-endian targets, silently turning "enable" into a no-op.
            const bytes = mem.asBytes(&@as(u32, @boolToInt(enabled)));
            return self.socket.setOption(os.IPPROTO.TCP, os.TCP.NODELAY, bytes);
        }
        return error.UnsupportedSocketOption;
    }

    /// Sets the `TCP.QUICKACK` option on the underlying socket. Returns
    /// `error.UnsupportedSocketOption` when the target's `os.TCP` namespace
    /// does not declare `QUICKACK`.
    pub fn setQuickACK(self: Client, enabled: bool) !void {
        if (@hasDecl(os.TCP, "QUICKACK")) {
            return self.socket.setOption(os.IPPROTO.TCP, os.TCP.QUICKACK, mem.asBytes(&@as(u32, @boolToInt(enabled))));
        }
        return error.UnsupportedSocketOption;
    }

    /// Sets the write buffer size of the underlying socket.
    pub fn setWriteBufferSize(self: Client, size: u32) !void {
        return self.socket.setWriteBufferSize(size);
    }

    /// Sets the read buffer size of the underlying socket.
    pub fn setReadBufferSize(self: Client, size: u32) !void {
        return self.socket.setReadBufferSize(size);
    }

    /// Sets the write timeout of the underlying socket, in milliseconds.
    pub fn setWriteTimeout(self: Client, milliseconds: u32) !void {
        return self.socket.setWriteTimeout(milliseconds);
    }

    /// Sets the read timeout of the underlying socket, in milliseconds.
    pub fn setReadTimeout(self: Client, milliseconds: u32) !void {
        return self.socket.setReadTimeout(milliseconds);
    }
};
/// A TCP listener: a thin wrapper around a `Socket` used to bind to an
/// address, listen, and accept incoming connections.
pub const Listener = struct {
    /// The underlying socket every operation is forwarded to.
    socket: Socket,

    /// Opens a new listener by creating a stream socket (`os.SOCK.STREAM`)
    /// using the TCP protocol (`os.IPPROTO.TCP`) over `domain`. `flags` is
    /// forwarded unchanged to `Socket.init`.
    pub fn init(domain: tcp.Domain, flags: std.enums.EnumFieldStruct(Socket.InitFlags, bool, false)) !Listener {
        return Listener{
            .socket = try Socket.init(
                @enumToInt(domain),
                os.SOCK.STREAM,
                os.IPPROTO.TCP,
                flags,
            ),
        };
    }

    /// Releases the listener by deinitializing its socket.
    pub fn deinit(self: Listener) void {
        self.socket.deinit();
    }

    /// Shuts down the receiving side (`.recv`) of the listener's socket.
    pub fn shutdown(self: Listener) !void {
        return self.socket.shutdown(.recv);
    }

    /// Binds the listener's socket to `address`.
    pub fn bind(self: Listener, address: ip.Address) !void {
        return self.socket.bind(address.into());
    }

    /// Starts listening on the socket; `max_backlog_size` is forwarded to
    /// `Socket.listen`.
    pub fn listen(self: Listener, max_backlog_size: u31) !void {
        return self.socket.listen(max_backlog_size);
    }

    /// Accepts a connection from the listener's socket and returns it as a
    /// `tcp.Connection`. `flags` is forwarded to `Socket.accept`.
    pub fn accept(self: Listener, flags: std.enums.EnumFieldStruct(Socket.InitFlags, bool, false)) !tcp.Connection {
        return tcp.Connection.from(try self.socket.accept(flags));
    }

    /// Queries the listener's socket for a pending error; forwards to
    /// `Socket.getError`.
    // The receiver is `Listener` (not `Client`) so that the method-call form
    // `listener.getError()` resolves like every other method on this type.
    pub fn getError(self: Listener) !void {
        return self.socket.getError();
    }

    /// Returns the local address of the listener's socket as an `ip.Address`.
    pub fn getLocalAddress(self: Listener) !ip.Address {
        return ip.Address.from(try self.socket.getLocalAddress());
    }

    /// Enables or disables address reuse on the listener's socket.
    pub fn setReuseAddress(self: Listener, enabled: bool) !void {
        return self.socket.setReuseAddress(enabled);
    }

    /// Enables or disables port reuse on the listener's socket.
    pub fn setReusePort(self: Listener, enabled: bool) !void {
        return self.socket.setReusePort(enabled);
    }

    /// Sets the `TCP.FASTOPEN` option on the listener's socket to 1 or 0
    /// depending on `enabled`. Returns `error.UnsupportedSocketOption` when
    /// the target's `os.TCP` namespace does not declare `FASTOPEN`.
    pub fn setFastOpen(self: Listener, enabled: bool) !void {
        if (@hasDecl(os.TCP, "FASTOPEN")) {
            return self.socket.setOption(os.IPPROTO.TCP, os.TCP.FASTOPEN, mem.asBytes(&@as(u32, @boolToInt(enabled))));
        }
        return error.UnsupportedSocketOption;
    }

    /// Sets the timeout applied to accepting connections, in milliseconds.
    /// This is implemented by setting the read timeout of the underlying
    /// socket.
    pub fn setAcceptTimeout(self: Listener, milliseconds: usize) !void {
        return self.socket.setReadTimeout(milliseconds);
    }
};
test "tcp: create client/listener pair" {
    if (native_os.tag == .wasi) return error.SkipZigTest;

    // Listener bound to the unspecified IPv4 address; the second argument to
    // `initIPv4` is 0 (presumably the port, leaving the choice to the OS).
    const listener = try tcp.Listener.init(.ip, .{ .close_on_exec = true });
    defer listener.deinit();

    const any_address = ip.Address.initIPv4(IPv4.unspecified, 0);
    try listener.bind(any_address);
    try listener.listen(128);

    // Take the address the listener actually bound to and swap its host for
    // localhost so the client has a concrete host to connect to.
    var target_address = try listener.getLocalAddress();
    switch (target_address) {
        .ipv4 => |*v4| v4.host = IPv4.localhost,
        .ipv6 => |*v6| v6.host = IPv6.localhost,
    }

    const client = try tcp.Client.init(.ip, .{ .close_on_exec = true });
    defer client.deinit();
    try client.connect(target_address);

    // Accepting must succeed now that the client has connected.
    const accepted = try listener.accept(.{ .close_on_exec = true });
    defer accepted.deinit();
}
test "tcp/client: 1ms read timeout" {
    if (native_os.tag == .wasi) return error.SkipZigTest;

    const listener = try tcp.Listener.init(.ip, .{ .close_on_exec = true });
    defer listener.deinit();

    const any_address = ip.Address.initIPv4(IPv4.unspecified, 0);
    try listener.bind(any_address);
    try listener.listen(128);

    // Rewrite the bound address's host to localhost so the client can reach it.
    var target_address = try listener.getLocalAddress();
    switch (target_address) {
        .ipv4 => |*v4| v4.host = IPv4.localhost,
        .ipv6 => |*v6| v6.host = IPv6.localhost,
    }

    const client = try tcp.Client.init(.ip, .{ .close_on_exec = true });
    defer client.deinit();
    try client.connect(target_address);
    try client.setReadTimeout(1);

    const accepted = try listener.accept(.{ .close_on_exec = true });
    defer accepted.deinit();

    // The accepted peer never writes anything, so a read on the client is
    // expected to give up with `error.WouldBlock`.
    var scratch: [1]u8 = undefined;
    const stream = client.reader(0);
    try testing.expectError(error.WouldBlock, stream.read(&scratch));
}
test "tcp/client: read and write multiple vectors" {
    if (native_os.tag == .wasi) return error.SkipZigTest;

    const listener = try tcp.Listener.init(.ip, .{ .close_on_exec = true });
    defer listener.deinit();

    const any_address = ip.Address.initIPv4(IPv4.unspecified, 0);
    try listener.bind(any_address);
    try listener.listen(128);

    // Rewrite the bound address's host to localhost so the client can reach it.
    var target_address = try listener.getLocalAddress();
    switch (target_address) {
        .ipv4 => |*v4| v4.host = IPv4.localhost,
        .ipv6 => |*v6| v6.host = IPv6.localhost,
    }

    const client = try tcp.Client.init(.ip, .{ .close_on_exec = true });
    defer client.deinit();
    try client.connect(target_address);

    const accepted = try listener.accept(.{ .close_on_exec = true });
    defer accepted.deinit();

    const message = "hello world";
    // Both the outgoing and the incoming message are split into two buffers
    // at the same index.
    const split_at = message.len / 2;

    // Send the message from the accepted side as two separate buffers.
    _ = try accepted.client.writeMessage(Socket.Message.fromBuffers(&[_]Buffer{
        Buffer.from(message[0..split_at]),
        Buffer.from(message[split_at..]),
    }), 0);

    // Receive on the client side into two buffers backed by one array.
    var received: [message.len + 1]u8 = undefined;
    var incoming = Socket.Message.fromBuffers(&[_]Buffer{
        Buffer.from(received[0..split_at]),
        Buffer.from(received[split_at..]),
    });
    _ = try client.readMessage(&incoming, 0);

    try testing.expectEqualStrings(message, received[0..message.len]);
}
test "tcp/listener: bind to unspecified ipv4 address" {
    if (native_os.tag == .wasi) return error.SkipZigTest;

    const listener = try tcp.Listener.init(.ip, .{ .close_on_exec = true });
    defer listener.deinit();

    const any_ipv4 = ip.Address.initIPv4(IPv4.unspecified, 0);
    try listener.bind(any_ipv4);
    try listener.listen(128);

    // The address reported back by the listener must be an IPv4 one.
    const bound = try listener.getLocalAddress();
    try testing.expect(bound == .ipv4);
}
test "tcp/listener: bind to unspecified ipv6 address" {
    if (native_os.tag == .wasi) return error.SkipZigTest;

    const listener = try tcp.Listener.init(.ipv6, .{ .close_on_exec = true });
    defer listener.deinit();

    const any_ipv6 = ip.Address.initIPv6(IPv6.unspecified, 0);
    try listener.bind(any_ipv6);
    try listener.listen(128);

    // The address reported back by the listener must be an IPv6 one.
    const bound = try listener.getLocalAddress();
    try testing.expect(bound == .ipv6);
}