// -*- C++ -*-
//===---------------------------- cmath -----------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef _LIBCUDACXX_CMATH
#define _LIBCUDACXX_CMATH

/*
    cmath synopsis

Macros:

    HUGE_VAL
    HUGE_VALF               // C99
    HUGE_VALL               // C99
    INFINITY                // C99
    NAN                     // C99
    FP_INFINITE             // C99
    FP_NAN                  // C99
    FP_NORMAL               // C99
    FP_SUBNORMAL            // C99
    FP_ZERO                 // C99
    FP_FAST_FMA             // C99
    FP_FAST_FMAF            // C99
    FP_FAST_FMAL            // C99
    FP_ILOGB0               // C99
    FP_ILOGBNAN             // C99
    MATH_ERRNO              // C99
    MATH_ERREXCEPT          // C99
    math_errhandling        // C99

namespace std
{

Types:

    float_t                 // C99
    double_t                // C99

// C90

floating_point abs(floating_point x);

floating_point acos (arithmetic x);
float          acosf(float x);
long double    acosl(long double x);

floating_point asin (arithmetic x);
float          asinf(float x);
long double    asinl(long double x);

floating_point atan (arithmetic x);
float          atanf(float x);
long double    atanl(long double x);

floating_point atan2 (arithmetic y, arithmetic x);
float          atan2f(float y, float x);
long double    atan2l(long double y, long double x);

floating_point ceil (arithmetic x);
float          ceilf(float x);
long double    ceill(long double x);

floating_point cos (arithmetic x);
float          cosf(float x);
long double    cosl(long double x);

floating_point cosh (arithmetic x);
float          coshf(float x);
long double    coshl(long double x);

floating_point exp (arithmetic x);
float          expf(float x);
long double    expl(long double x);

floating_point fabs (arithmetic x);
float          fabsf(float x);
long double    fabsl(long double x);

floating_point floor (arithmetic x);
float          floorf(float x);
long double    floorl(long double x);

floating_point fmod (arithmetic x, arithmetic y);
float          fmodf(float x, float y);
long double    fmodl(long double x, long double y);

floating_point frexp (arithmetic value, int* exp);
float          frexpf(float value, int* exp);
long double    frexpl(long double value, int* exp);

floating_point ldexp (arithmetic value, int exp);
float          ldexpf(float value, int exp);
long double    ldexpl(long double value, int exp);

floating_point log (arithmetic x);
float          logf(float x);
long double    logl(long double x);

floating_point log10 (arithmetic x);
float          log10f(float x);
long double    log10l(long double x);

floating_point modf (floating_point value, floating_point* iptr);
float          modff(float value, float* iptr);
long double    modfl(long double value, long double* iptr);

floating_point pow (arithmetic x, arithmetic y);
float          powf(float x, float y);
long double    powl(long double x, long double y);

floating_point sin (arithmetic x);
float          sinf(float x);
long double    sinl(long double x);

floating_point sinh (arithmetic x);
float          sinhf(float x);
long double    sinhl(long double x);

floating_point sqrt (arithmetic x);
float          sqrtf(float x);
long double    sqrtl(long double x);

floating_point tan (arithmetic x);
float          tanf(float x);
long double    tanl(long double x);

floating_point tanh (arithmetic x);
float          tanhf(float x);
long double    tanhl(long double x);

//  C99

bool signbit(arithmetic x);

int fpclassify(arithmetic x);

bool isfinite(arithmetic x);
bool isinf(arithmetic x);
bool isnan(arithmetic x);
bool isnormal(arithmetic x);

bool isgreater(arithmetic x, arithmetic y);
bool isgreaterequal(arithmetic x, arithmetic y);
bool isless(arithmetic x, arithmetic y);
bool islessequal(arithmetic x, arithmetic y);
bool islessgreater(arithmetic x, arithmetic y);
bool isunordered(arithmetic x, arithmetic y);

floating_point acosh (arithmetic x);
float          acoshf(float x);
long double    acoshl(long double x);

floating_point asinh (arithmetic x);
float          asinhf(float x);
long double    asinhl(long double x);

floating_point atanh (arithmetic x);
float          atanhf(float x);
long double    atanhl(long double x);

floating_point cbrt (arithmetic x);
float          cbrtf(float x);
long double    cbrtl(long double x);

floating_point copysign (arithmetic x, arithmetic y);
float          copysignf(float x, float y);
long double    copysignl(long double x, long double y);

floating_point erf (arithmetic x);
float          erff(float x);
long double    erfl(long double x);

floating_point erfc (arithmetic x);
float          erfcf(float x);
long double    erfcl(long double x);

floating_point exp2 (arithmetic x);
float          exp2f(float x);
long double    exp2l(long double x);

floating_point expm1 (arithmetic x);
float          expm1f(float x);
long double    expm1l(long double x);

floating_point fdim (arithmetic x, arithmetic y);
float          fdimf(float x, float y);
long double    fdiml(long double x, long double y);

floating_point fma (arithmetic x, arithmetic y, arithmetic z);
float          fmaf(float x, float y, float z);
long double    fmal(long double x, long double y, long double z);

floating_point fmax (arithmetic x, arithmetic y);
float          fmaxf(float x, float y);
long double    fmaxl(long double x, long double y);

floating_point fmin (arithmetic x, arithmetic y);
float          fminf(float x, float y);
long double    fminl(long double x, long double y);

floating_point hypot (arithmetic x, arithmetic y);
float          hypotf(float x, float y);
long double    hypotl(long double x, long double y);

double       hypot(double x, double y, double z);                // C++17
float        hypot(float x, float y, float z);                   // C++17
long double  hypot(long double x, long double y, long double z); // C++17

int ilogb (arithmetic x);
int ilogbf(float x);
int ilogbl(long double x);

floating_point lgamma (arithmetic x);
float          lgammaf(float x);
long double    lgammal(long double x);

long long llrint (arithmetic x);
long long llrintf(float x);
long long llrintl(long double x);

long long llround (arithmetic x);
long long llroundf(float x);
long long llroundl(long double x);

floating_point log1p (arithmetic x);
float          log1pf(float x);
long double    log1pl(long double x);

floating_point log2 (arithmetic x);
float          log2f(float x);
long double    log2l(long double x);

floating_point logb (arithmetic x);
float          logbf(float x);
long double    logbl(long double x);

long lrint (arithmetic x);
long lrintf(float x);
long lrintl(long double x);

long lround (arithmetic x);
long lroundf(float x);
long lroundl(long double x);

double      nan (const char* str);
float       nanf(const char* str);
long double nanl(const char* str);

floating_point nearbyint (arithmetic x);
float          nearbyintf(float x);
long double    nearbyintl(long double x);

floating_point nextafter (arithmetic x, arithmetic y);
float          nextafterf(float x, float y);
long double    nextafterl(long double x, long double y);

floating_point nexttoward (arithmetic x, long double y);
float          nexttowardf(float x, long double y);
long double    nexttowardl(long double x, long double y);

floating_point remainder (arithmetic x, arithmetic y);
float          remainderf(float x, float y);
long double    remainderl(long double x, long double y);

floating_point remquo (arithmetic x, arithmetic y, int* pquo);
float          remquof(float x, float y, int* pquo);
long double    remquol(long double x, long double y, int* pquo);

floating_point rint (arithmetic x);
float          rintf(float x);
long double    rintl(long double x);

floating_point round (arithmetic x);
float          roundf(float x);
long double    roundl(long double x);

floating_point scalbln (arithmetic x, long ex);
float          scalblnf(float x, long ex);
long double    scalblnl(long double x, long ex);

floating_point scalbn (arithmetic x, int ex);
float          scalbnf(float x, int ex);
long double    scalbnl(long double x, int ex);

floating_point tgamma (arithmetic x);
float          tgammaf(float x);
long double    tgammal(long double x);

floating_point trunc (arithmetic x);
float          truncf(float x);
long double    truncl(long double x);

}  // std

*/

#include <cuda/std/detail/__config>

#if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC)
#  pragma GCC system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_CLANG)
#  pragma clang system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_MSVC)
#  pragma system_header
#endif // no system header

#if !_CCCL_COMPILER(NVRTC)
#  include <math.h>
#endif // !_CCCL_COMPILER(NVRTC)

#if _CCCL_COMPILER(NVHPC)
#  include <cmath>
#endif // _CCCL_COMPILER(NVHPC)

#include <cuda/std/__cmath/abs.h>
#include <cuda/std/__cmath/copysign.h>
#include <cuda/std/__cmath/exponential_functions.h>
#include <cuda/std/__cmath/fma.h>
#include <cuda/std/__cmath/fpclassify.h>
#include <cuda/std/__cmath/gamma.h>
#include <cuda/std/__cmath/hyperbolic_functions.h>
#include <cuda/std/__cmath/hypot.h>
#include <cuda/std/__cmath/inverse_hyperbolic_functions.h>
#include <cuda/std/__cmath/inverse_trigonometric_functions.h>
#include <cuda/std/__cmath/isfinite.h>
#include <cuda/std/__cmath/isinf.h>
#include <cuda/std/__cmath/isnan.h>
#include <cuda/std/__cmath/isnormal.h>
#include <cuda/std/__cmath/lerp.h>
#include <cuda/std/__cmath/logarithms.h>
#include <cuda/std/__cmath/min_max.h>
#include <cuda/std/__cmath/modulo.h>
#include <cuda/std/__cmath/remainder.h>
#include <cuda/std/__cmath/roots.h>
#include <cuda/std/__cmath/rounding_functions.h>
#include <cuda/std/__cmath/signbit.h>
#include <cuda/std/__cmath/traits.h>
#include <cuda/std/__cmath/trigonometric_functions.h>
#include <cuda/std/__cstdlib/abs.h>
#include <cuda/std/limits>
#include <cuda/std/type_traits>
#include <cuda/std/version>

// <math.h> is not included when compiling with NVRTC (see the guarded include
// above), so provide the C99 INFINITY and NAN macros here, expressed through
// numeric_limits<float>.
#if _CCCL_COMPILER(NVRTC)
#  define INFINITY ::cuda::std::numeric_limits<float>::infinity()
#  define NAN      ::cuda::std::numeric_limits<float>::quiet_NaN()
#endif // _CCCL_COMPILER(NVRTC)

#include <cuda/std/__cccl/prologue.h>

_CCCL_BEGIN_NAMESPACE_CUDA_STD

// Re-export the global-namespace <math.h> names listed below into this
// namespace. Skipped for NVRTC, where <math.h> is not included by this header.
#if !_CCCL_COMPILER(NVRTC)

// C99 floating-point evaluation types.
using ::double_t;
using ::float_t;

// Error functions and positive difference (double and float variants).
using ::erf;
using ::erfc;
using ::erfcf;
using ::erff;
using ::fdim;
using ::fdimf;

// NaN generators.
using ::nan;
using ::nanf;

// long double variants of the functions above.
using ::erfcl;
using ::erfl;
using ::fdiml;
using ::nanl;

#endif // !_CCCL_COMPILER(NVRTC)

// _CCCL_CONSTEXPR_CXX14_COMPLEX expands to `constexpr` when
// _LIBCUDACXX_HAS_CONSTEXPR_COMPLEX_OPERATIONS() is true and to nothing
// otherwise. It is applied to the __constexpr_* helpers defined below.
#if _LIBCUDACXX_HAS_CONSTEXPR_COMPLEX_OPERATIONS()
#  define _CCCL_CONSTEXPR_CXX14_COMPLEX constexpr
#else
#  define _CCCL_CONSTEXPR_CXX14_COMPLEX
#endif // !_LIBCUDACXX_HAS_CONSTEXPR_COMPLEX_OPERATIONS()

// __constexpr_logb(x): unbiased radix-independent exponent of x, returned as a
// floating-point value (same contract as C99 logb):
//   x == 0   -> -infinity
//   x == inf -> +infinity
//   x == NaN -> quiet NaN
//   else     -> floor(log_radix(|x|))
#if _CCCL_COMPILER(MSVC) || _CCCL_COMPILER(NVRTC) || _CCCL_CUDA_COMPILER(CLANG)
// Non-constexpr variant: simply forwards to ::cuda::std::logb.
template <class _A1>
_CCCL_API inline _A1 __constexpr_logb(_A1 __x)
{
  return ::cuda::std::logb(__x);
}
#else // ^^^ MSVC / NVRTC / clang-cuda ^^^ / vvv other compilers vvv
// Potentially-constexpr variant: evaluates the exponent by hand during constant
// evaluation and defers to the compiler builtin at run time.
template <class _Tp>
_CCCL_API inline _CCCL_CONSTEXPR_CXX14_COMPLEX _Tp __constexpr_logb(_Tp __x)
{
#  if defined(_CCCL_BUILTIN_IS_CONSTANT_EVALUATED) && _LIBCUDACXX_HAS_CONSTEXPR_COMPLEX_OPERATIONS()
  if (_CCCL_BUILTIN_IS_CONSTANT_EVALUATED())
  {
    if (__x == _Tp(0))
    {
      // raise FE_DIVBYZERO
      return -numeric_limits<_Tp>::infinity();
    }

    if (::cuda::std::isinf(__x))
    {
      return numeric_limits<_Tp>::infinity();
    }

    if (::cuda::std::isnan(__x))
    {
      return numeric_limits<_Tp>::quiet_NaN();
    }

    const _Tp __radix = _Tp(numeric_limits<_Tp>::radix);

    // Normalize |x| into [1, radix) while tracking the (signed) exponent.
    __x             = ::cuda::std::fabs(__x);
    long long __exp = 0;
    while (__x >= __radix)
    {
      __x /= __radix;
      ++__exp;
    }
    // Magnitudes below one (including subnormals) have a negative exponent;
    // without this loop every such input would incorrectly report 0.
    while (__x < _Tp(1))
    {
      __x *= __radix;
      --__exp;
    }
    return _Tp(__exp);
  }
#  endif // defined(_CCCL_BUILTIN_IS_CONSTANT_EVALUATED)
  // NOTE(review): __builtin_logb is the double-typed builtin, so other argument
  // types are converted to double for this call - confirm this is intended for
  // long double.
  return __builtin_logb(__x);
}
#endif // !(_CCCL_COMPILER(MSVC) || _CCCL_COMPILER(NVRTC) || _CCCL_CUDA_COMPILER(CLANG))

// __constexpr_scalbn(x, e): returns x * radix^e (same contract as C99 scalbn).
#if _CCCL_COMPILER(MSVC) || _CCCL_COMPILER(NVRTC) || _CCCL_CUDA_COMPILER(CLANG)
// Non-constexpr variant. The primary template goes through the double-precision
// ::scalbn and converts back; the specializations below call the C function
// matching the argument's precision.
template <class _Tp>
_CCCL_API inline _Tp __constexpr_scalbn(_Tp __x, int __i)
{
  return static_cast<_Tp>(::scalbn(static_cast<double>(__x), __i));
}

template <>
_CCCL_API inline float __constexpr_scalbn<float>(float __x, int __i)
{
  return ::scalbnf(__x, __i);
}

template <>
_CCCL_API inline double __constexpr_scalbn<double>(double __x, int __i)
{
  return ::scalbn(__x, __i);
}

#  if _CCCL_HAS_LONG_DOUBLE()
template <>
_CCCL_API inline long double __constexpr_scalbn<long double>(long double __x, int __i)
{
  return ::scalbnl(__x, __i);
}
#  endif // _CCCL_HAS_LONG_DOUBLE()
#else // ^^^ MSVC / NVRTC / clang-cuda ^^^ / vvv other compilers vvv
// Potentially-constexpr variant: scales by hand during constant evaluation and
// defers to the compiler builtin at run time.
template <class _Tp>
_CCCL_API inline _CCCL_CONSTEXPR_CXX14_COMPLEX _Tp __constexpr_scalbn(_Tp __x, int __exp)
{
#  if defined(_CCCL_BUILTIN_IS_CONSTANT_EVALUATED) && _LIBCUDACXX_HAS_CONSTEXPR_COMPLEX_OPERATIONS()
  if (_CCCL_BUILTIN_IS_CONSTANT_EVALUATED())
  {
    // Zero and infinity are unaffected by scaling (sign is preserved).
    if (__x == _Tp(0))
    {
      return __x;
    }

    if (::cuda::std::isinf(__x))
    {
      return __x;
    }

    if (__exp == 0)
    {
      return __x;
    }

    if (::cuda::std::isnan(__x))
    {
      return numeric_limits<_Tp>::quiet_NaN();
    }

    // Choose the per-step factor (radix or 1/radix) and the number of steps.
    // The magnitude is computed in unsigned arithmetic so that the most
    // negative int does not overflow on negation.
    _Tp __mult       = _Tp(numeric_limits<_Tp>::radix);
    unsigned int __n = static_cast<unsigned int>(__exp);
    if (__exp < 0)
    {
      __mult = _Tp(1) / __mult;
      __n    = 0u - __n;
    }

    // Square-and-multiply. Loop invariant: result == __x * __mult^__n.
    // The full magnitude must be consumed here; pre-decrementing it without a
    // matching multiplication would leave the result off by one factor of radix.
    // NOTE(review): for very large magnitudes the squared factor itself can
    // overflow/underflow before it is applied - confirm callers stay in range.
    while (__n > 0)
    {
      if ((__n & 1u) == 0)
      {
        __mult *= __mult;
        __n >>= 1;
      }
      else
      {
        __x *= __mult;
        --__n;
      }
    }
    return __x;
  }
#  endif // defined(_CCCL_BUILTIN_IS_CONSTANT_EVALUATED)
  // NOTE(review): __builtin_scalbn is the double-typed builtin, so other
  // argument types are converted to double for this call - confirm this is
  // intended for long double.
  return __builtin_scalbn(__x, __exp);
}
#endif // !(_CCCL_COMPILER(MSVC) || _CCCL_COMPILER(NVRTC) || _CCCL_CUDA_COMPILER(CLANG))

_CCCL_END_NAMESPACE_CUDA_STD

#include <cuda/std/__cccl/epilogue.h>

#endif // _LIBCUDACXX_CMATH
