const std = @import("../std.zig");
const builtin = @import("builtin");
const assert = std.debug.assert;
const testing = std.testing;
const Loop = std.event.Loop;

/// Many producer, many consumer, thread-safe, runtime configurable buffer size.
/// When the buffer is empty, consumers suspend and are resumed by producers.
/// When the buffer is full, producers suspend and are resumed by consumers.
/// Only usable with event-based I/O: `put`, `get` and `getOrNull` reference the
/// global event loop, and referencing them without one is a compile error.
pub fn Channel(comptime T: type) type {
    return struct {
        /// Suspended consumers from both `get` and `getOrNull`.
        getters: std.atomic.Queue(GetNode),
        /// The `getters` entries that came from `getOrNull`. Any still listed here at the
        /// end of a dispatch pass are removed from `getters` and resumed with `null`.
        or_null_queue: std.atomic.Queue(*std.atomic.Queue(GetNode).Node),
        /// Suspended producers.
        putters: std.atomic.Queue(PutNode),
        /// Count of nodes published to `getters`; incremented after the node is queued.
        get_count: usize,
        /// Count of nodes published to `putters`; incremented after the node is queued.
        put_count: usize,
        /// Held by the one frame currently running the dispatch loop.
        dispatch_lock: bool,
        /// Set when a dispatch is requested; the lock holder re-runs while it is set.
        need_dispatch: bool,

        // Simple fixed size ring buffer, only modified by `dispatch` while holding
        // `dispatch_lock`. The `buffer_len` stored items are the slots ending just
        // before `buffer_index` (both taken modulo `buffer_nodes.len`).
        buffer_nodes: []T,
        buffer_index: usize,
        buffer_len: usize,

        const SelfChannel = @This();
        const GetNode = struct {
            tick_node: *Loop.NextTickNode,
            data: Data,

            const Data = union(enum) {
                Normal: Normal,
                OrNull: OrNull,
            };

            const Normal = struct {
                ptr: *T,
            };

            const OrNull = struct {
                ptr: *?T,
                or_null: *std.atomic.Queue(*std.atomic.Queue(GetNode).Node).Node,
            };
        };
        const PutNode = struct {
            data: T,
            tick_node: *Loop.NextTickNode,
        };

        const global_event_loop = Loop.instance orelse
            @compileError("std.event.Channel currently only works with event-based I/O");

        /// Initialize the channel in place. Call `deinit` to free resources when done.
        /// `buffer` must live until `deinit` is called, and its length must be 0 or a
        /// power of two (asserted).
        /// For a zero length buffer, use `[0]T{}`.
        /// TODO https://github.com/ziglang/zig/issues/2765
        pub fn init(self: *SelfChannel, buffer: []T) void {
            // The ring buffer implementation only works with power of 2 buffer sizes
            // because it relies on index arithmetic wrapping across zero. For example
            // (0 -% 1) % 10 == 5 for usize, which is not the last slot of a 10-slot buffer.
            assert(buffer.len == 0 or @popCount(buffer.len) == 1);

            self.* = SelfChannel{
                .buffer_len = 0,
                .buffer_nodes = buffer,
                .buffer_index = 0,
                .dispatch_lock = false,
                .need_dispatch = false,
                .getters = std.atomic.Queue(GetNode).init(),
                .putters = std.atomic.Queue(PutNode).init(),
                .or_null_queue = std.atomic.Queue(*std.atomic.Queue(GetNode).Node).init(),
                .get_count = 0,
                .put_count = 0,
            };
        }

        /// Must be called when all calls to put and get have suspended and no more calls occur.
        /// Resumes every still-suspended getter, then every still-suspended putter, and
        /// leaves `self` undefined.
        /// This can be omitted if caller can guarantee that the suspended putters and getters
        /// do not need to be run to completion. Note that this may leave awaiters hanging.
        pub fn deinit(self: *SelfChannel) void {
            while (self.getters.get()) |get_node| {
                resume get_node.data.tick_node.data;
            }
            while (self.putters.get()) |put_node| {
                resume put_node.data.tick_node.data;
            }
            self.* = undefined;
        }

        /// Puts a data item in the channel. The function returns when the value has been
        /// added to the buffer, or, for a zero size buffer, when the item has been retrieved
        /// by a getter; or when the channel is deinitialized.
        pub fn put(self: *SelfChannel, data: T) void {
            var my_tick_node = Loop.NextTickNode{ .data = @frame() };
            var queue_node = std.atomic.Queue(PutNode).Node{
                .data = PutNode{
                    .tick_node = &my_tick_node,
                    .data = data,
                },
            };

            suspend {
                // Publish the node before bumping the counter so dispatch never counts
                // a node it cannot dequeue.
                self.putters.put(&queue_node);
                _ = @atomicRmw(usize, &self.put_count, .Add, 1, .SeqCst);

                self.dispatch();
            }
        }

        /// Await this function to get an item from the channel. If the buffer is empty,
        /// the frame completes when the next item is put in the channel.
        pub fn get(self: *SelfChannel) callconv(.Async) T {
            // TODO https://github.com/ziglang/zig/issues/2765
            var result: T = undefined;
            var my_tick_node = Loop.NextTickNode{ .data = @frame() };
            var queue_node = std.atomic.Queue(GetNode).Node{
                .data = GetNode{
                    .tick_node = &my_tick_node,
                    .data = GetNode.Data{
                        .Normal = GetNode.Normal{ .ptr = &result },
                    },
                },
            };

            suspend {
                self.getters.put(&queue_node);
                _ = @atomicRmw(usize, &self.get_count, .Add, 1, .SeqCst);

                self.dispatch();
            }
            return result;
        }

        // Sketch of a possible `select` API (not implemented):
        //pub async fn select(comptime EnumUnion: type, channels: ...) EnumUnion {
        //    assert(@memberCount(EnumUnion) == channels.len); // enum union and channels mismatch
        //    assert(channels.len != 0); // enum unions cannot have 0 fields
        //    if (channels.len == 1) {
        //        const result = await (async channels[0].get() catch unreachable);
        //        return @unionInit(EnumUnion, @memberName(EnumUnion, 0), result);
        //    }
        //}

        /// Get an item from the channel. If the buffer is empty and there are no
        /// puts waiting, this returns `null`.
        ///
        /// NOTE(review): the node is published to `getters` (and counted) before it is
        /// added to `or_null_queue`. A concurrent dispatch can satisfy it in that window, in
        /// which case its `or_null_queue.remove` finds nothing, and a later pass still pops
        /// the `or_null` node and schedules this frame again. This looks like a possible
        /// double resume in multi-threaded builds — confirm before relying on it across threads.
        pub fn getOrNull(self: *SelfChannel) ?T {
            // TODO integrate this function with named return values
            // so we can get rid of this extra result copy
            var result: ?T = null;
            var my_tick_node = Loop.NextTickNode{ .data = @frame() };
            var or_null_node = std.atomic.Queue(*std.atomic.Queue(GetNode).Node).Node{ .data = undefined };
            var queue_node = std.atomic.Queue(GetNode).Node{
                .data = GetNode{
                    .tick_node = &my_tick_node,
                    .data = GetNode.Data{
                        .OrNull = GetNode.OrNull{
                            .ptr = &result,
                            .or_null = &or_null_node,
                        },
                    },
                },
            };
            or_null_node.data = &queue_node;

            suspend {
                self.getters.put(&queue_node);
                _ = @atomicRmw(usize, &self.get_count, .Add, 1, .SeqCst);
                self.or_null_queue.put(&or_null_node);

                self.dispatch();
            }
            return result;
        }

        /// Match waiting getters with buffered items and waiting putters, move leftover
        /// putters into free buffer space, then resume unsatisfied `getOrNull` callers.
        ///
        /// Only one frame runs the matching loop at a time (`dispatch_lock`). A caller that
        /// finds the lock held only leaves `need_dispatch` set; the lock holder keeps
        /// running passes until no further dispatch was requested.
        fn dispatch(self: *SelfChannel) void {
            // Request a dispatch pass.
            @atomicStore(bool, &self.need_dispatch, true, .SeqCst);

            lock: while (true) {
                // Try to take the lock; if it is already held, the holder will see
                // need_dispatch and do the work.
                if (@atomicRmw(bool, &self.dispatch_lock, .Xchg, true, .SeqCst)) return;

                // Clear the need_dispatch flag since we're about to do it.
                @atomicStore(bool, &self.need_dispatch, false, .SeqCst);

                while (true) {
                    {
                        // Each Sub returns the value *before* subtracting, so the locals
                        // hold how many queued nodes may still be consumed, while the shared
                        // counters stay one lower. The extra subtraction is undone below.
                        var get_count = @atomicRmw(usize, &self.get_count, .Sub, 1, .SeqCst);
                        var put_count = @atomicRmw(usize, &self.put_count, .Sub, 1, .SeqCst);

                        // Transfer buffered items (oldest first) to waiting getters. When no
                        // getter is waiting, fall through so the putters below can still be
                        // moved into any free buffer space.
                        while (self.buffer_len != 0 and get_count != 0) {
                            const get_node = &self.getters.get().?.data;
                            const item = self.buffer_nodes[(self.buffer_index -% self.buffer_len) % self.buffer_nodes.len];
                            switch (get_node.data) {
                                .Normal => |info| {
                                    info.ptr.* = item;
                                },
                                .OrNull => |info| {
                                    // Satisfied, so it must not also be resumed with null.
                                    _ = self.or_null_queue.remove(info.or_null);
                                    info.ptr.* = item;
                                },
                            }
                            global_event_loop.onNextTick(get_node.tick_node);
                            self.buffer_len -= 1;

                            get_count = @atomicRmw(usize, &self.get_count, .Sub, 1, .SeqCst);
                        }

                        // Direct transfer from putters to getters. This only runs once the
                        // buffer is empty or no getters remain, so buffered items keep
                        // their place ahead of waiting putters.
                        while (get_count != 0 and put_count != 0) {
                            const get_node = &self.getters.get().?.data;
                            const put_node = &self.putters.get().?.data;

                            switch (get_node.data) {
                                .Normal => |info| {
                                    info.ptr.* = put_node.data;
                                },
                                .OrNull => |info| {
                                    _ = self.or_null_queue.remove(info.or_null);
                                    info.ptr.* = put_node.data;
                                },
                            }
                            global_event_loop.onNextTick(get_node.tick_node);
                            global_event_loop.onNextTick(put_node.tick_node);

                            get_count = @atomicRmw(usize, &self.get_count, .Sub, 1, .SeqCst);
                            put_count = @atomicRmw(usize, &self.put_count, .Sub, 1, .SeqCst);
                        }

                        // Move remaining putters into free buffer space and let them continue.
                        while (self.buffer_len != self.buffer_nodes.len and put_count != 0) {
                            const put_node = &self.putters.get().?.data;

                            self.buffer_nodes[self.buffer_index % self.buffer_nodes.len] = put_node.data;
                            global_event_loop.onNextTick(put_node.tick_node);
                            self.buffer_index +%= 1;
                            self.buffer_len += 1;

                            put_count = @atomicRmw(usize, &self.put_count, .Sub, 1, .SeqCst);
                        }
                    }

                    // Undo the extra subtractions made at the start of the pass.
                    _ = @atomicRmw(usize, &self.get_count, .Add, 1, .SeqCst);
                    _ = @atomicRmw(usize, &self.put_count, .Add, 1, .SeqCst);

                    // Every `getOrNull` caller still listed in `or_null_queue` received no
                    // value in this pass: drop it from `getters` and resume it, leaving its
                    // result at `null`.
                    var remove_count: usize = 0;
                    while (self.or_null_queue.get()) |or_null_node| {
                        remove_count += @boolToInt(self.getters.remove(or_null_node.data));
                        global_event_loop.onNextTick(or_null_node.data.data.tick_node);
                    }
                    if (remove_count != 0) {
                        _ = @atomicRmw(usize, &self.get_count, .Sub, remove_count, .SeqCst);
                    }

                    // If another dispatch was requested during this pass, run one more
                    // pass while still holding the lock.
                    if (@atomicRmw(bool, &self.need_dispatch, .Xchg, false, .SeqCst)) continue;

                    const was_locked = @atomicRmw(bool, &self.dispatch_lock, .Xchg, false, .SeqCst);
                    assert(was_locked);

                    // A request may have arrived between the check above and the unlock;
                    // check again now that we unlocked.
                    if (@atomicLoad(bool, &self.need_dispatch, .SeqCst)) continue :lock;

                    return;
                }
            }
        }
    };
}

test "std.event.Channel" {
    if (!std.io.is_async) return error.SkipZigTest;

    // https://github.com/ziglang/zig/issues/1908
    if (builtin.single_threaded) return error.SkipZigTest;

    // https://github.com/ziglang/zig/issues/3251
    if (builtin.os.tag == .freebsd) return error.SkipZigTest;

    var channel: Channel(i32) = undefined;
    channel.init(&[0]i32{});
    defer channel.deinit();

    var handle = async testChannelGetter(&channel);
    var putter = async testChannelPutter(&channel);

    // The getter reports failed expectations through its error union.
    try await handle;
    await putter;
}

test "std.event.Channel wraparound" {
    // TODO provide a way to run tests in evented I/O mode
    if (!std.io.is_async) return error.SkipZigTest;

    const channel_size = 2;

    var buf: [channel_size]i32 = undefined;
    var channel: Channel(i32) = undefined;
    channel.init(&buf);
    defer channel.deinit();

    // Add items to the channel and pull them out until the buffer wraps around,
    // making sure it doesn't crash.
    channel.put(5);
    try testing.expectEqual(@as(i32, 5), channel.get());
    channel.put(6);
    try testing.expectEqual(@as(i32, 6), channel.get());
    channel.put(7);
    try testing.expectEqual(@as(i32, 7), channel.get());
}

/// Receives the two values sent by `testChannelPutter`, then checks `getOrNull`
/// with no pending putter (expects `null`) and with one pending putter (expects its value).
/// Returns an error when an expectation fails, so the caller must `try await` it.
fn testChannelGetter(channel: *Channel(i32)) callconv(.Async) !void {
    const value1 = channel.get();
    try testing.expectEqual(@as(i32, 1234), value1);

    const value2 = channel.get();
    try testing.expectEqual(@as(i32, 4567), value2);

    const value3 = channel.getOrNull();
    try testing.expectEqual(@as(?i32, null), value3);

    var last_put = async testPut(channel, 4444);
    const value4 = channel.getOrNull();
    try testing.expectEqual(@as(?i32, 4444), value4);
    await last_put;
}

/// Sends the two values `testChannelGetter` expects to receive first.
fn testChannelPutter(channel: *Channel(i32)) callconv(.Async) void {
    channel.put(1234);
    channel.put(4567);
}

/// Sends a single `value` into `channel`.
fn testPut(channel: *Channel(i32), value: i32) callconv(.Async) void {
    channel.put(value);
}