// Copyright (C) 2023 Adam Lugowski. All rights reserved.
// Use of this source code is governed by:
// the BSD 2-clause license, the MIT license, or at your choosing the BSL-1.0 license found in the LICENSE.*.txt files.
// SPDX-License-Identifier: BSD-2-Clause OR MIT OR BSL-1.0


#ifndef POOLSTL_NUMERIC_HPP
#define POOLSTL_NUMERIC_HPP

#include <functional>
#include <tuple>

#include "execution"
#include "internal/ttp_impl.hpp"

namespace std {

#if POOLSTL_HAVE_CXX17_LIB
    /**
     * Parallel exclusive scan for poolSTL execution policies.
     *
     * NOTE: Iterators are expected to be random access.
     * See std::exclusive_scan https://en.cppreference.com/w/cpp/algorithm/exclusive_scan
     *
     * Works in two passes: every chunk of the input is first folded into a single partial result, the
     * partial results are scanned sequentially to obtain the starting value of each chunk, and finally
     * every chunk is scanned into its own slice of the output.
     *
     * @param policy execution policy. If is_seq reports it as sequential, the work is delegated to the
     *               non-policy std::exclusive_scan.
     * @param first  start of the input range.
     * @param last   end of the input range.
     * @param dest   start of the output range.
     * @param init   initial value of the scan.
     * @param binop  binary operation used to combine values. Chunks are folded independently, so only an
     *               associative operation gives the same result as a sequential scan.
     * @return dest unchanged for an empty input, otherwise dest + (last - first).
     */
    template <class ExecPolicy, class RandIt1, class RandIt2, class T, class BinaryOp>
    poolstl::internal::enable_if_poolstl_policy<ExecPolicy, RandIt2>
    exclusive_scan(ExecPolicy &&policy, RandIt1 first, RandIt1 last, RandIt2 dest, T init, BinaryOp binop) {
        if (first == last) {
            return dest;
        }

        if (poolstl::internal::is_seq<ExecPolicy>(policy)) {
            return std::exclusive_scan(first, last, dest, init, binop);
        }

        using ChunkRange = std::pair<RandIt1, RandIt1>;

        // Pass 1: Chunk the input and fold each chunk into one partial result.
        // The fold is seeded with the chunk's own first element instead of T{}: T{} is an identity only
        // for addition-like operations (with std::multiplies it would zero every partial result).
        auto futures = poolstl::internal::parallel_chunk_for_gen(std::forward<ExecPolicy>(policy), first, last,
                             [binop, init](RandIt1 chunk_first, RandIt1 chunk_last) -> std::tuple<ChunkRange, T> {
                                 if (chunk_first == chunk_last) {
                                     // Nothing to fold. The value is only a placeholder; empty ranges
                                     // are dropped when the results are gathered below.
                                     return std::make_tuple(ChunkRange(chunk_first, chunk_last), init);
                                 }
                                 T sum = std::accumulate(chunk_first + 1, chunk_last,
                                                         static_cast<T>(*chunk_first), binop);
                                 return std::make_tuple(ChunkRange(chunk_first, chunk_last), std::move(sum));
                             });

        std::vector<ChunkRange> ranges;
        std::vector<T> sums;

        for (auto& future : futures) {
            auto res = future.get();
            const ChunkRange& range = std::get<0>(res);
            if (range.first == range.second) {
                // An empty chunk contributes no elements and must not contribute a partial result.
                continue;
            }
            ranges.push_back(range);
            sums.push_back(std::move(std::get<1>(res)));
        }

        // Find the initial value for each range: init combined with the results of all earlier chunks.
        std::exclusive_scan(sums.begin(), sums.end(), sums.begin(), init, binop);

        // Pass 2: perform exclusive scan of each chunk, using the sum of previous chunks as init
        std::vector<std::tuple<RandIt1, RandIt1, RandIt2, T>> args;
        args.reserve(ranges.size());
        for (std::size_t i = 0; i < ranges.size(); ++i) {
            const RandIt1 chunk_first = ranges[i].first;
            // Each chunk writes to the output slice at the same offset it occupies in the input.
            args.emplace_back(chunk_first, ranges[i].second,
                              dest + (chunk_first - first),
                              std::move(sums[i]));
        }

        auto futures2 = poolstl::internal::parallel_apply(std::forward<ExecPolicy>(policy),
            [binop](RandIt1 chunk_first, RandIt1 chunk_last, RandIt2 chunk_dest, T chunk_init){
                std::exclusive_scan(chunk_first, chunk_last, chunk_dest, chunk_init, binop);
            }, args);

        // Wait for every chunk to finish before reporting the end of the output.
        poolstl::internal::get_futures(futures2);
        return dest + (last - first);
    }

    /**
     * Exclusive scan that combines values with addition.
     *
     * NOTE: Iterators are expected to be random access.
     * See std::exclusive_scan https://en.cppreference.com/w/cpp/algorithm/exclusive_scan
     *
     * Forwards to the policy-taking std::exclusive_scan overload that accepts a binary operation,
     * supplying std::plus<T> as that operation.
     *
     * @param policy execution policy, forwarded unchanged.
     * @param first  start of the input range.
     * @param last   end of the input range.
     * @param dest   start of the output range.
     * @param init   initial value of the scan.
     * @return whatever the forwarded-to overload returns.
     */
    template <class ExecPolicy, class RandIt1, class RandIt2, class T>
    poolstl::internal::enable_if_poolstl_policy<ExecPolicy, RandIt2>
    exclusive_scan(ExecPolicy &&policy, RandIt1 first, RandIt1 last, RandIt2 dest, T init) {
        std::plus<T> add_op{};
        return std::exclusive_scan(std::forward<ExecPolicy>(policy), first, last, dest, init, add_op);
    }
#endif

    /**
     * Parallel reduce for poolSTL execution policies.
     *
     * NOTE: Iterators are expected to be random access.
     * See std::reduce https://en.cppreference.com/w/cpp/algorithm/reduce
     *
     * The input is handed to the chunking helper, each chunk is reduced to one partial result, and the
     * partial results are then folded together with init inside this function. init takes part in the
     * final fold only, so it does not have to be an identity element of binop.
     *
     * @param policy execution policy. If is_seq reports it as sequential, the range is reduced directly.
     * @param first  start of the input range.
     * @param last   end of the input range.
     * @param init   initial value of the reduction.
     * @param binop  binary operation used to combine values. Chunks are reduced independently, so the
     *               result matches a sequential reduction only for an associative operation.
     * @return init combined with every element of [first, last).
     */
    template <class ExecPolicy, class RandIt, class T, class BinaryOp>
    poolstl::internal::enable_if_poolstl_policy<ExecPolicy, T>
    reduce(ExecPolicy &&policy, RandIt first, RandIt last, T init, BinaryOp binop) {
        if (poolstl::internal::is_seq<ExecPolicy>(policy)) {
            return poolstl::internal::cpp17::reduce(first, last, init, binop);
        }

        // The chunk callable keeps the (first, last, init, binop) parameter list of cpp17::reduce so the
        // helper is invoked exactly as before. It seeds the chunk with the chunk's first element rather
        // than with init: seeding every chunk with init counted init once per chunk plus once more in
        // the final fold.
        auto futures = poolstl::internal::parallel_chunk_for_1(std::forward<ExecPolicy>(policy), first, last,
            [](RandIt chunk_first, RandIt chunk_last, T fallback, BinaryOp chunk_op) -> T {
                if (chunk_first == chunk_last) {
                    // NOTE(review): an empty chunk is not expected from the chunking helper - confirm.
                    // If one occurs, return what the previous implementation produced for it.
                    return fallback;
                }
                T seed = static_cast<T>(*chunk_first);
                ++chunk_first;
                return poolstl::internal::cpp17::reduce(chunk_first, chunk_last, std::move(seed), chunk_op);
            },
            static_cast<T*>(nullptr), 1, init, binop);

        // Fold the per-chunk results; this is the only place init enters the result.
        return poolstl::internal::cpp17::reduce(
            poolstl::internal::get_wrap(futures.begin()),
            poolstl::internal::get_wrap(futures.end()), init, binop);
    }

    /**
     * Reduce that combines values with addition.
     *
     * NOTE: Iterators are expected to be random access.
     * See std::reduce https://en.cppreference.com/w/cpp/algorithm/reduce
     *
     * Forwards to the policy-taking std::reduce overload that accepts a binary operation, supplying
     * std::plus<T> as that operation.
     *
     * @param policy execution policy, forwarded unchanged.
     * @param first  start of the input range.
     * @param last   end of the input range.
     * @param init   initial value of the reduction.
     * @return whatever the forwarded-to overload returns.
     */
    template <class ExecPolicy, class RandIt, class T>
    poolstl::internal::enable_if_poolstl_policy<ExecPolicy, T>
    reduce(ExecPolicy &&policy, RandIt first, RandIt last, T init) {
        std::plus<T> add_op{};
        return std::reduce(std::forward<ExecPolicy>(policy), first, last, init, add_op);
    }

    /**
     * Reduce that starts from a value-initialized element and combines values with addition.
     *
     * NOTE: Iterators are expected to be random access.
     * See std::reduce https://en.cppreference.com/w/cpp/algorithm/reduce
     *
     * Forwards to the policy-taking std::reduce overload that accepts an initial value, supplying a
     * value-initialized object of the iterator's value_type.
     *
     * @param policy execution policy, forwarded unchanged.
     * @param first  start of the input range.
     * @param last   end of the input range.
     * @return whatever the forwarded-to overload returns.
     */
    template <class ExecPolicy, class RandIt>
    poolstl::internal::enable_if_poolstl_policy<
        ExecPolicy, typename std::iterator_traits<RandIt>::value_type>
    reduce(ExecPolicy &&policy, RandIt first, RandIt last) {
        using element_type = typename std::iterator_traits<RandIt>::value_type;
        return std::reduce(std::forward<ExecPolicy>(policy), first, last, element_type{});
    }

#if POOLSTL_HAVE_CXX17_LIB
    /**
     * Parallel unary transform_reduce for poolSTL execution policies.
     *
     * NOTE: Iterators are expected to be random access.
     * See std::transform_reduce https://en.cppreference.com/w/cpp/algorithm/transform_reduce
     *
     * The input is handed to the chunking helper, each chunk is transformed and reduced to one partial
     * result, and the partial results are then folded together with init inside this function. init
     * takes part in the final fold only, so it does not have to be an identity element of reduce_op.
     *
     * @param policy       execution policy. If is_seq reports it as sequential, the non-policy
     *                     std::transform_reduce is used.
     * @param first1       start of the input range.
     * @param last1        end of the input range.
     * @param init         initial value of the reduction.
     * @param reduce_op    binary operation combining transformed values. Chunks are reduced
     *                     independently, so it should be associative.
     * @param transform_op unary operation applied to every element. Its result is converted to T to
     *                     seed each chunk.
     * @return init combined with transform_op(x) for every element x of [first1, last1).
     */
    template <class ExecPolicy, class RandIt1, class T, class BinaryReductionOp, class UnaryTransformOp>
    poolstl::internal::enable_if_poolstl_policy<ExecPolicy, T>
    transform_reduce(ExecPolicy&& policy, RandIt1 first1, RandIt1 last1, T init,
                     BinaryReductionOp reduce_op, UnaryTransformOp transform_op) {
        if (poolstl::internal::is_seq<ExecPolicy>(policy)) {
            return std::transform_reduce(first1, last1, init, reduce_op, transform_op);
        }

        // A lambda is used instead of a pointer to std::transform_reduce: standard library function
        // templates are not guaranteed to be addressable. The lambda keeps the same parameter list, so
        // the helper is invoked exactly as before. It seeds the chunk with its first transformed
        // element rather than with init: seeding every chunk with init counted init once per chunk
        // plus once more in the final fold.
        auto futures = poolstl::internal::parallel_chunk_for_1(std::forward<ExecPolicy>(policy), first1, last1,
            [](RandIt1 chunk_first, RandIt1 chunk_last, T fallback,
               BinaryReductionOp chunk_reduce_op, UnaryTransformOp chunk_transform_op) -> T {
                if (chunk_first == chunk_last) {
                    // NOTE(review): an empty chunk is not expected from the chunking helper - confirm.
                    // If one occurs, return what the previous implementation produced for it.
                    return fallback;
                }
                T seed = static_cast<T>(chunk_transform_op(*chunk_first));
                ++chunk_first;
                return std::transform_reduce(chunk_first, chunk_last, std::move(seed),
                                             chunk_reduce_op, chunk_transform_op);
            },
            static_cast<T*>(nullptr), 1, init, reduce_op, transform_op);

        // Fold the per-chunk results; this is the only place init enters the result.
        return poolstl::internal::cpp17::reduce(
            poolstl::internal::get_wrap(futures.begin()),
            poolstl::internal::get_wrap(futures.end()), init, reduce_op);
    }

    /**
     * Parallel binary transform_reduce for poolSTL execution policies.
     *
     * NOTE: Iterators are expected to be random access.
     * See std::transform_reduce https://en.cppreference.com/w/cpp/algorithm/transform_reduce
     *
     * Both inputs are handed to the chunking helper, each pair of chunks is transformed and reduced to
     * one partial result, and the partial results are then folded together with init inside this
     * function. init takes part in the final fold only, so it does not have to be an identity element
     * of reduce_op.
     *
     * @param policy       execution policy. If is_seq reports it as sequential, the non-policy
     *                     std::transform_reduce is used.
     * @param first1       start of the first input range.
     * @param last1        end of the first input range.
     * @param first2       start of the second input range, read in step with the first.
     * @param init         initial value of the reduction.
     * @param reduce_op    binary operation combining transformed values. Chunks are reduced
     *                     independently, so it should be associative.
     * @param transform_op binary operation applied to each pair of corresponding elements. Its result
     *                     is converted to T to seed each chunk.
     * @return init combined with transform_op(a, b) for every pair of corresponding elements.
     */
    template <class ExecPolicy, class RandIt1, class RandIt2, class T, class BinaryReductionOp, class BinaryTransformOp>
    poolstl::internal::enable_if_poolstl_policy<ExecPolicy, T>
    transform_reduce(ExecPolicy&& policy, RandIt1 first1, RandIt1 last1, RandIt2 first2, T init,
                     BinaryReductionOp reduce_op, BinaryTransformOp transform_op) {
        if (poolstl::internal::is_seq<ExecPolicy>(policy)) {
            return std::transform_reduce(first1, last1, first2, init, reduce_op, transform_op);
        }

        // A lambda is used instead of a pointer to std::transform_reduce: standard library function
        // templates are not guaranteed to be addressable. The lambda keeps the same parameter list, so
        // the helper is invoked exactly as before. It seeds the chunk with its first transformed pair
        // rather than with init: seeding every chunk with init counted init once per chunk plus once
        // more in the final fold.
        auto futures = poolstl::internal::parallel_chunk_for_2(std::forward<ExecPolicy>(policy), first1, last1, first2,
            [](RandIt1 chunk_first1, RandIt1 chunk_last1, RandIt2 chunk_first2, T fallback,
               BinaryReductionOp chunk_reduce_op, BinaryTransformOp chunk_transform_op) -> T {
                if (chunk_first1 == chunk_last1) {
                    // NOTE(review): an empty chunk is not expected from the chunking helper - confirm.
                    // If one occurs, return what the previous implementation produced for it.
                    return fallback;
                }
                T seed = static_cast<T>(chunk_transform_op(*chunk_first1, *chunk_first2));
                ++chunk_first1;
                ++chunk_first2;
                return std::transform_reduce(chunk_first1, chunk_last1, chunk_first2, std::move(seed),
                                             chunk_reduce_op, chunk_transform_op);
            },
            static_cast<T*>(nullptr), init, reduce_op, transform_op);

        // Fold the per-chunk results; this is the only place init enters the result.
        return poolstl::internal::cpp17::reduce(
            poolstl::internal::get_wrap(futures.begin()),
            poolstl::internal::get_wrap(futures.end()), init, reduce_op);
    }

    /**
     * Binary transform_reduce with the default operations: multiply corresponding elements and add the
     * products (an inner product).
     *
     * NOTE: Iterators are expected to be random access.
     * See std::transform_reduce https://en.cppreference.com/w/cpp/algorithm/transform_reduce
     *
     * Forwards to the policy-taking std::transform_reduce overload that accepts explicit operations,
     * supplying std::plus<>() as the reduction and std::multiplies<>() as the transformation.
     *
     * @param policy execution policy, forwarded unchanged.
     * @param first1 start of the first input range.
     * @param last1  end of the first input range.
     * @param first2 start of the second input range.
     * @param init   initial value of the reduction.
     * @return whatever the forwarded-to overload returns.
     */
    template< class ExecPolicy, class RandIt1, class RandIt2, class T >
    poolstl::internal::enable_if_poolstl_policy<ExecPolicy, T>
    transform_reduce(ExecPolicy&& policy, RandIt1 first1, RandIt1 last1, RandIt2 first2, T init ) {
        // Qualified with std:: like the other forwarding overloads in this file, so that
        // argument-dependent lookup cannot pull in an unrelated transform_reduce from the
        // namespace of the iterator or value types.
        return std::transform_reduce(std::forward<ExecPolicy>(policy),
            first1, last1, first2, init, std::plus<>(), std::multiplies<>());
    }
#endif

}

#endif
