// Copyright (C) 2023 Adam Lugowski. All rights reserved.
// Use of this source code is governed by:
// the BSD 2-clause license, the MIT license, or at your choosing the BSL-1.0 license found in the LICENSE.*.txt files.
// SPDX-License-Identifier: BSD-2-Clause OR MIT OR BSL-1.0


#ifndef POOLSTL_EXECUTION_HPP
#define POOLSTL_EXECUTION_HPP

#include <memory>
#include <mutex>
#include <stdexcept>
#include <type_traits>

#include "internal/task_thread_pool.hpp"
#include "internal/utils.hpp"

namespace poolstl {

    namespace ttp = task_thread_pool;

    namespace execution {
        namespace internal {
            /**
             * Accessor for the shared thread pool that `par` uses when no pool was
             * selected explicitly.
             *
             * The pool is created by the first call. Creation goes through `std::call_once`,
             * so the pool is constructed a single time no matter how many callers arrive.
             *
             * @return A shared pointer to the default pool.
             */
            inline std::shared_ptr<ttp::task_thread_pool> get_default_pool() {
                static std::once_flag create_once;
                static std::shared_ptr<ttp::task_thread_pool> default_pool;

                const auto create_pool = [&]() {
                    default_pool = std::make_shared<ttp::task_thread_pool>();
                };
                std::call_once(create_once, create_pool);

                return default_pool;
            }
        }

        /**
         * Common root of every poolSTL execution policy.
         *
         * It has no members of its own; deriving from it is what marks a type
         * as a poolSTL policy.
         */
        struct poolstl_policy {};

        /**
         * Sequential policy: algorithms called with it simply forward to the
         * ordinary non-policy overload.
         */
        struct sequenced_policy : public poolstl_policy {
            /**
             * @return Always false; this policy never permits parallel execution.
             */
            POOLSTL_NO_DISCARD bool par_allowed() const {
                return false;
            }

            /**
             * Present only so that code written against the policy interface compiles
             * under C++11. It is not expected to be called.
             *
             * @throws std::runtime_error unconditionally.
             */
            POOLSTL_NO_DISCARD ttp::task_thread_pool* pool() const {
                throw std::runtime_error("poolSTL: requested thread pool for seq policy.");
            }
        };

        /**
         * Parallel policy that runs either on a caller-supplied thread pool or,
         * when none was supplied, on the default pool.
         *
         * `on()` and `par_if()` are const and return a new policy; they never modify
         * the policy they are called on.
         */
        struct parallel_policy : public poolstl_policy {
            parallel_policy() = default;
            explicit parallel_policy(ttp::task_thread_pool* on_pool, bool par_ok): on_pool(on_pool), par_ok(par_ok) {}

            /**
             * @return Whether parallel execution is permitted for this policy.
             */
            POOLSTL_NO_DISCARD bool par_allowed() const {
                return par_ok;
            }

            /**
             * @return The explicitly selected pool if there is one, otherwise the default pool.
             */
            POOLSTL_NO_DISCARD ttp::task_thread_pool* pool() const {
                return on_pool ? on_pool : internal::get_default_pool().get();
            }

            /**
             * @param call_par Value the returned policy reports from `par_allowed()`.
             * @return A copy of this policy with the same pool selection and the new flag.
             */
            parallel_policy par_if(bool call_par) const {
                return parallel_policy(on_pool, call_par);
            }

            /**
             * @param pool Thread pool the returned policy should use. Only its address is
             *             stored, so the pool must outlive every use of the returned policy.
             * @return A copy of this policy bound to `pool`, keeping the current parallel flag.
             */
            parallel_policy on(ttp::task_thread_pool& pool) const {
                return parallel_policy(&pool, par_ok);
            }

        protected:
            // Explicitly selected pool; nullptr means "use the default pool".
            ttp::task_thread_pool *on_pool = nullptr;
            // Whether parallel execution is permitted.
            bool par_ok = true;
        };

        // Ready-made policy objects, so callers can pass `seq` or `par` directly
        // instead of constructing a policy themselves.
        // `seq` is the sequential policy.
        constexpr sequenced_policy seq{};
        // `par` is a default-constructed parallel policy: no explicit pool selected
        // and parallel execution allowed.
        constexpr parallel_policy par{};

        /**
         * EXPERIMENTAL: Subject to significant changes or removal.
         * Use pure threads for each operation instead of a shared thread pool.
         *
         * Advantage:
         *  - Fewer symbols (no packaged_task with its operators, destructors, vtable, etc) means smaller binary
         *    which can mean a lot when there are many calls.
         *  - No thread pool to manage.
         *
         * Disadvantages:
         *  - Threads are started and joined for every operation, so it is harder to amortize that cost.
         *  - Barely any algorithms are supported.
         */
        struct pure_threads_policy : public poolstl_policy {
            /**
             * @param num_threads Number of threads to use. 0 means "pick automatically"
             *                    (see `get_num_threads()`).
             * @param par_ok      Whether parallel execution is permitted.
             */
            explicit pure_threads_policy(unsigned int num_threads, bool par_ok): num_threads(num_threads),
                                                                                 par_ok(par_ok) {}

            /**
             * @return The requested thread count if it is non-zero; otherwise the value of
             *         `std::thread::hardware_concurrency()`, or 1 if that is unavailable.
             *         The result is therefore never 0.
             */
            POOLSTL_NO_DISCARD unsigned int get_num_threads() const {
                if (num_threads != 0) {
                    return num_threads;
                }

                // hardware_concurrency() is only a hint and may return 0 when the value
                // is not computable. Fall back to a single thread so callers never
                // receive a thread count of zero.
                const unsigned int detected = std::thread::hardware_concurrency();
                return detected != 0 ? detected : 1;
            }

            /**
             * @return Whether parallel execution is permitted for this policy.
             */
            POOLSTL_NO_DISCARD bool par_allowed() const {
                return par_ok;
            }

        protected:
            // Requested thread count; 0 requests automatic selection.
            unsigned int num_threads = 1;
            // Whether parallel execution is permitted.
            bool par_ok = true;
        };

        /**
         * Decide between parallel and sequential execution at runtime.
         *
         * @param call_par Whether parallel execution should be permitted.
         * @return A parallel policy with no explicit pool selected whose `par_allowed()`
         *         reports `call_par`: it acts like `par` when true and like a sequential
         *         policy (like `seq`) when false.
         */
        inline parallel_policy par_if(bool call_par) {
            return parallel_policy(nullptr, call_par);
        }

        /**
         * Decide between parallel and sequential execution at runtime, on a chosen pool.
         *
         * @param call_par Whether parallel execution should be permitted.
         * @param pool     Thread pool the returned policy is bound to. Only its address is
         *                 stored, so the pool must outlive every use of the returned policy.
         * @return A parallel policy bound to `pool` whose `par_allowed()` reports `call_par`:
         *         it acts like `par.on(pool)` when true and like a sequential policy
         *         (like `seq`) when false.
         */
        inline parallel_policy par_if(bool call_par, ttp::task_thread_pool& pool) {
            return parallel_policy(&pool, call_par);
        }

        /**
         * EXPERIMENTAL: Subject to significant changes or removal. See `pure_threads_policy`.
         * Decide between parallel and sequential execution at runtime, with a chosen thread count.
         *
         * @param call_par    Whether parallel execution should be permitted.
         * @param num_threads Number of threads the returned policy should use.
         * @return A `pure_threads_policy` built from `num_threads` and `call_par`.
         */
        inline pure_threads_policy par_if_threads(bool call_par, unsigned int num_threads) {
            return pure_threads_policy(num_threads, call_par);
        }
    }

    // Re-export the most commonly used names from `poolstl::execution` so they can be
    // written as `poolstl::seq`, `poolstl::par` and `poolstl::par_if`.
    using execution::seq;
    using execution::par;
    using execution::par_if;

    namespace internal {
        /**
         * SFINAE helper: names `Result` only when `Policy`, with references and
         * cv-qualifiers stripped, is exactly `sequenced_policy`. Otherwise the alias
         * has no type, which removes the overload using it from consideration.
         */
        template <class Policy, class Result>
        using enable_if_seq = typename std::enable_if<
            std::is_same<
                poolstl::execution::sequenced_policy,
                typename std::remove_cv<typename std::remove_reference<Policy>::type>::type
            >::value,
            Result
        >::type;

        /**
         * SFINAE helper: names `Result` only when `Policy`, with references and
         * cv-qualifiers stripped, is exactly `parallel_policy`. Otherwise the alias
         * has no type, which removes the overload using it from consideration.
         */
        template <class Policy, class Result>
        using enable_if_par = typename std::enable_if<
            std::is_same<
                poolstl::execution::parallel_policy,
                typename std::remove_cv<typename std::remove_reference<Policy>::type>::type
            >::value,
            Result
        >::type;

        /**
         * SFINAE helper: names `Result` only when `Policy`, with references and
         * cv-qualifiers stripped, is `poolstl_policy` or a class derived from it.
         * Otherwise the alias has no type, which removes the overload using it
         * from consideration.
         */
        template <class Policy, class Result>
        using enable_if_poolstl_policy = typename std::enable_if<
            std::is_base_of<
                poolstl::execution::poolstl_policy,
                typename std::remove_cv<typename std::remove_reference<Policy>::type>::type
            >::value,
            Result
        >::type;

        /**
         * Runtime check for whether a policy object asks for sequential execution.
         *
         * @param policy Any object providing a `par_allowed()` member.
         * @return true when `policy.par_allowed()` is false, otherwise false.
         */
        template <class ExecPolicy>
        bool is_seq(const ExecPolicy& policy) {
            if (policy.par_allowed()) {
                return false;
            }
            return true;
        }

        /**
         * Type trait whose `value` is true exactly when `Policy`, with references and
         * cv-qualifiers stripped, is `pure_threads_policy`.
         */
        template <class Policy>
        using is_pure_threads_policy = std::is_same<
            poolstl::execution::pure_threads_policy,
            typename std::remove_cv<typename std::remove_reference<Policy>::type>::type
        >;
    }
}

#endif
