// SPDX-FileCopyrightText: Copyright (c) 2008-2020, NVIDIA Corporation. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include <thrust/detail/config.h>

#if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC)
#  pragma GCC system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_CLANG)
#  pragma clang system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_MSVC)
#  pragma system_header
#endif // no system header
#include <thrust/detail/type_traits.h>
#include <thrust/detail/type_traits/has_nested_type.h>
#include <thrust/detail/type_traits/is_thrust_pointer.h>
#include <thrust/iterator/iterator_traits.h>

#include <cuda/std/__memory/addressof.h>
#include <cuda/std/__type_traits/add_lvalue_reference.h>
#include <cuda/std/__type_traits/conjunction.h>
#include <cuda/std/__type_traits/enable_if.h>
#include <cuda/std/__type_traits/is_comparable.h>
#include <cuda/std/__type_traits/is_convertible.h>
#include <cuda/std/__type_traits/is_same.h>
#include <cuda/std/__type_traits/is_void.h>
#include <cuda/std/__type_traits/type_identity.h>
#include <cuda/std/cstddef>

THRUST_NAMESPACE_BEGIN
namespace detail
{
// Computes the element (pointee) type of a pointer-like type.
//
// Only the two partial specializations below are defined; instantiating
// `pointer_element` with any other kind of type is a hard error.
template <typename Ptr>
struct pointer_element;

// Raw pointers: the element type is the pointee, so `U*` maps to `U`.
template <typename U>
struct pointer_element<U*>
{
  using type = U;
};

// Class-template pointers `Fancy<Elem, Rest...>`: the element type is taken to
// be the first template argument; the remaining arguments are ignored.
template <template <typename...> class Fancy, typename Elem, typename... Rest>
struct pointer_element<Fancy<Elem, Rest...>>
{
  using type = Elem;
};

// Computes the difference type of a pointer-like type.
//
// Class-type pointers are required to expose a nested `difference_type`.
template <typename Ptr>
struct pointer_difference
{
  using type = typename Ptr::difference_type;
};

// Raw pointers expose no nested typedefs; their difference type is always
// `ptrdiff_t`.
template <typename Elem>
struct pointer_difference<Elem*>
{
  using type = ::cuda::std::ptrdiff_t;
};

// Computes the type obtained by rebinding the pointer type `Ptr` so that it
// points to `T` instead of its current element type.
//
// Only the partial specializations below are defined; an unsupported `Ptr` is a
// hard error. Several specializations can match the same `Ptr`; the standard
// partial-ordering rules then select the most specialized one.
template <typename Ptr, typename T>
struct rebind_pointer;

// Raw pointers: `T*` rebinds to `U*`.
template <typename T, typename U>
struct rebind_pointer<T*, U>
{
  using type = U*;
};

// Rebind generic fancy pointers: the first template argument (the element type)
// is replaced by `T`, and all remaining arguments (`Tail...`) are kept as-is.
template <template <typename, typename...> class Ptr, typename OldT, typename... Tail, typename T>
struct rebind_pointer<Ptr<OldT, Tail...>, T>
{
  using type = Ptr<T, Tail...>;
};

// Rebind `thrust::pointer`-like things with `thrust::reference`-like references.
// The old element type appears both as the first argument of `Ptr` and as the
// first argument of the reference template `Ref`; both occurrences are replaced
// by `T`. `Tag`, `RefTail...` and `PtrTail...` are kept unchanged.
template <template <typename, typename, typename, typename...> class Ptr,
          typename OldT,
          typename Tag,
          template <typename...> class Ref,
          typename... RefTail,
          typename... PtrTail,
          typename T>
struct rebind_pointer<Ptr<OldT, Tag, Ref<OldT, RefTail...>, PtrTail...>, T>
{
  using type = Ptr<T, Tag, Ref<T, RefTail...>, PtrTail...>;
};

// Rebind `thrust::pointer`-like things with `thrust::reference`-like references
// and templated derived types. This matches exactly four arguments, the fourth
// being a template (`DerivedPtr`) whose first argument is also the old element
// type; all three occurrences of `OldT` are replaced by `T`.
template <template <typename, typename, typename, typename...> class Ptr,
          typename OldT,
          typename Tag,
          template <typename...> class Ref,
          typename... RefTail,
          template <typename...> class DerivedPtr,
          typename... DerivedPtrTail,
          typename T>
struct rebind_pointer<Ptr<OldT, Tag, Ref<OldT, RefTail...>, DerivedPtr<OldT, DerivedPtrTail...>>, T>
{
  using type = Ptr<T, Tag, Ref<T, RefTail...>, DerivedPtr<T, DerivedPtrTail...>>;
};

// Rebind `thrust::pointer`-like things with native reference types. The third
// argument is `add_lvalue_reference_t<OldT>` (i.e. `OldT&`, or `OldT` itself
// when `OldT` is cv void) and is rebuilt from `T` the same way.
template <template <typename, typename, typename, typename...> class Ptr,
          typename OldT,
          typename Tag,
          typename... PtrTail,
          typename T>
struct rebind_pointer<Ptr<OldT, Tag, ::cuda::std::add_lvalue_reference_t<OldT>, PtrTail...>, T>
{
  using type = Ptr<T, Tag, ::cuda::std::add_lvalue_reference_t<T>, PtrTail...>;
};

// Rebind `thrust::pointer`-like things with native reference types and templated
// derived types: combines the two previous cases, replacing the element type in
// `Ptr`, in the native reference, and in `DerivedPtr`.
template <template <typename, typename, typename, typename...> class Ptr,
          typename OldT,
          typename Tag,
          template <typename...> class DerivedPtr,
          typename... DerivedPtrTail,
          typename T>
struct rebind_pointer<Ptr<OldT, Tag, ::cuda::std::add_lvalue_reference_t<OldT>, DerivedPtr<OldT, DerivedPtrTail...>>, T>
{
  using type = Ptr<T, Tag, ::cuda::std::add_lvalue_reference_t<T>, DerivedPtr<T, DerivedPtrTail...>>;
};

namespace pointer_traits_detail
{
// Stand-in parameter type for `pointer_to` when the element type is (cv) void,
// for which a reference parameter cannot be formed.
//
// It converts implicitly from any lvalue, records that lvalue's address as a
// `Void*`, and overloads unary `operator&` to return the recorded address. This
// lets a `pointer_to` body write `&r` regardless of whether `r` is a real
// reference or a `capture_address`.
template <typename Void>
struct capture_address
{
  // Non-explicit, so an lvalue can be passed directly where a
  // `capture_address` parameter is expected.
  // NOTE(review): `&r` uses `T`'s own unary `operator&` if `T` overloads it; the
  // result must then be convertible to `Void*` for this to compile.
  template <typename T>
  _CCCL_HOST_DEVICE capture_address(T& r)
      : m_addr(&r)
  {}

  // Returns the captured address, not the address of this wrapper object.
  inline _CCCL_HOST_DEVICE Void* operator&() const
  {
    return m_addr;
  }

  // Address of the object this wrapper was constructed from.
  Void* m_addr;
};

// metafunction to compute the type of pointer_to's parameter below:
// `capture_address<T>` when `is_void_v<T>` holds (cv void), otherwise `T&`.
template <typename T>
struct pointer_to_param
    : thrust::detail::eval_if<::cuda::std::is_void_v<T>,
                              ::cuda::std::type_identity<capture_address<T>>,
                              ::cuda::std::add_lvalue_reference<T>>
{};
} // namespace pointer_traits_detail

// Thrust's analogue of `std::pointer_traits`, used here for class-type (fancy)
// pointers.
//
// Requirements on `Ptr` visible from this definition: a nested `reference`
// type, a constructor accepting a raw pointer (used by `pointer_to`), and a
// `get()` member returning the underlying raw pointer (used by `get`).
template <typename Ptr>
struct pointer_traits
{
  using pointer         = Ptr;
  using reference       = typename Ptr::reference;
  using element_type    = typename pointer_element<Ptr>::type;
  using difference_type = typename pointer_difference<Ptr>::type;

  // `rebind<U>::other` is `Ptr` re-targeted to point to `U`.
  template <typename U>
  struct rebind
  {
    using other = typename rebind_pointer<Ptr, U>::type;
  };

  // Returns a `Ptr` designating `obj`.
  //
  // `std::pointer_traits` would forward to a static `Ptr::pointer_to(obj)`;
  // Thrust instead assumes `Ptr` can be constructed from a raw address. When
  // `element_type` is (cv) void, `obj` is a `capture_address` and `&obj`
  // yields the address it captured.
  _CCCL_HOST_DEVICE static pointer
  pointer_to(typename pointer_traits_detail::pointer_to_param<element_type>::type obj)
  {
    return pointer(&obj);
  }

  // Thrust extension: the raw pointer type underlying `Ptr`.
  using raw_pointer = typename pointer_raw_pointer<Ptr>::type;

  // Thrust extension: unwraps `p` into its raw pointer via `p.get()`.
  _CCCL_HOST_DEVICE static raw_pointer get(pointer p)
  {
    return p.get();
  }
};

template <typename T>
struct pointer_traits<T*>
{
  using pointer         = T*;
  using reference       = T&;
  using element_type    = T;
  using difference_type = typename pointer_difference<T*>::type;

  template <typename U>
  struct rebind
  {
    using other = U*;
  };

  _CCCL_HOST_DEVICE inline static pointer
  pointer_to(typename pointer_traits_detail::pointer_to_param<element_type>::type r)
  {
    return &r;
  }

  // thrust additions follow
  using raw_pointer = typename pointer_raw_pointer<T*>::type;

  _CCCL_HOST_DEVICE inline static raw_pointer get(pointer ptr)
  {
    return ptr;
  }
};

// Full specialization for untyped `void*`.
//
// `void&` cannot be formed, so `reference` is plain `void` and `pointer_to`
// takes a `capture_address<void>` (selected by `pointer_to_param<void>`).
template <>
struct pointer_traits<void*>
{
  using pointer         = void*;
  using reference       = void;
  using element_type    = void;
  using difference_type = pointer_difference<pointer>::type;

  // `rebind<Other>::other` is simply `Other*`.
  template <typename Other>
  struct rebind
  {
    using other = Other*;
  };

  // `captured` is a `capture_address<void>`; its unary `operator&` returns the
  // address it recorded on construction.
  _CCCL_HOST_DEVICE static pointer pointer_to(pointer_traits_detail::pointer_to_param<element_type>::type captured)
  {
    return &captured;
  }

  // Thrust extension: the raw pointer type for `void*`.
  using raw_pointer = pointer_raw_pointer<pointer>::type;

  // Thrust extension: a raw pointer unwraps to itself.
  _CCCL_HOST_DEVICE static raw_pointer get(pointer p)
  {
    return p;
  }
};

// Full specialization for untyped `const void*`.
//
// `const void&` cannot be formed, so `reference` is plain `const void` and
// `pointer_to` takes a `capture_address<const void>` (selected by
// `pointer_to_param<const void>`).
template <>
struct pointer_traits<const void*>
{
  using pointer         = const void*;
  using reference       = const void;
  using element_type    = const void;
  using difference_type = pointer_difference<pointer>::type;

  // `rebind<Other>::other` is simply `Other*`.
  template <typename Other>
  struct rebind
  {
    using other = Other*;
  };

  // `captured` is a `capture_address<const void>`; its unary `operator&`
  // returns the address it recorded on construction.
  _CCCL_HOST_DEVICE static pointer pointer_to(pointer_traits_detail::pointer_to_param<element_type>::type captured)
  {
    return &captured;
  }

  // Thrust extension: the raw pointer type for `const void*`.
  using raw_pointer = pointer_raw_pointer<pointer>::type;

  // Thrust extension: a raw pointer unwraps to itself.
  _CCCL_HOST_DEVICE static raw_pointer get(pointer p)
  {
    return p;
  }
};

// True when the system of `From` (as reported by `iterator_system_t`) is
// implicitly convertible to the system of `To`.
template <typename From, typename To>
inline constexpr bool is_pointer_system_convertible_v =
  ::cuda::std::is_convertible_v<iterator_system_t<From>, //
                                iterator_system_t<To>>;

// True when a `From` pointer may be implicitly converted to a `To` pointer:
// the raw element pointers must be implicitly convertible and the systems must
// be compatible.
template <typename From, typename To>
inline constexpr bool is_pointer_convertible_v =
  ::cuda::std::is_convertible_v<typename pointer_element<From>::type*, //
                                typename pointer_element<To>::type*>
  && is_pointer_system_convertible_v<From, To>;

// True when `From` points to (cv) void and its system is compatible with the
// system of `To`.
template <typename From, typename To>
inline constexpr bool is_void_pointer_system_convertible_v =
  ::cuda::std::is_void_v<typename pointer_element<From>::type>
  && is_pointer_system_convertible_v<From, To>;

// avoid inspecting traits of the arguments if they aren't known to be pointers:
// the defaulted `bool` parameter is true only when both arguments satisfy
// `is_thrust_pointer_v`. Otherwise the primary template yields `false`, and the
// eager traits (`pointer_element`, `iterator_system_t`) are never instantiated
// on the arguments.
template <typename FromPtr, typename ToPtr, bool = is_thrust_pointer_v<FromPtr> && is_thrust_pointer_v<ToPtr>>
inline constexpr bool lazy_is_pointer_convertible_v = false;

// Both arguments are Thrust pointers: defer to the full convertibility check.
template <typename FromPtr, typename ToPtr>
inline constexpr bool lazy_is_pointer_convertible_v<FromPtr, ToPtr, true> = is_pointer_convertible_v<FromPtr, ToPtr>;

// Same guarding scheme as above, for the void-pointer system check.
template <typename FromPtr, typename ToPtr, bool = is_thrust_pointer_v<FromPtr> && is_thrust_pointer_v<ToPtr>>
inline constexpr bool lazy_is_void_pointer_system_convertible_v = false;

template <typename FromPtr, typename ToPtr>
inline constexpr bool lazy_is_void_pointer_system_convertible_v<FromPtr, ToPtr, true> =
  is_void_pointer_system_convertible_v<FromPtr, ToPtr>;

// SFINAE helper: names `T` (default `void`) when `FromPtr` and `ToPtr` are both
// Thrust pointers and `FromPtr` is convertible to `ToPtr`; otherwise
// substitution fails.
template <typename FromPtr, typename ToPtr, typename T = void>
using enable_if_pointer_is_convertible_t = ::cuda::std::enable_if_t<lazy_is_pointer_convertible_v<FromPtr, ToPtr>, T>;

// SFINAE helper: names `T` (default `void`) when both are Thrust pointers,
// `FromPtr` points to (cv) void, and the systems are compatible; otherwise
// substitution fails.
template <typename FromPtr, typename ToPtr, typename T = void>
using enable_if_void_pointer_is_system_convertible_t =
  ::cuda::std::enable_if_t<lazy_is_void_pointer_system_convertible_v<FromPtr, ToPtr>, T>;

// Tagged pointers can be compared for equality only when their systems are
// compatible and the underlying raw pointers (`value_type*`) are equality
// comparable.
template <typename LhsPtr, typename RhsPtr>
inline constexpr bool ptr_can_compare_equal =
  is_pointer_system_convertible_v<LhsPtr, RhsPtr>
  && ::cuda::std::__is_cpp17_equality_comparable_v<typename LhsPtr::value_type*, //
                                                   typename RhsPtr::value_type*>;

// Tagged pointers can be ordered with `<` only when their systems are
// compatible and the underlying raw pointers (`value_type*`) are less-than
// comparable.
template <typename LhsPtr, typename RhsPtr>
inline constexpr bool ptr_can_compare_less_than =
  is_pointer_system_convertible_v<LhsPtr, RhsPtr>
  && ::cuda::std::__is_cpp17_less_than_comparable_v<typename LhsPtr::value_type*, //
                                                    typename RhsPtr::value_type*>;

// Tagged references can be compared for equality only when their systems are
// compatible and the referenced `value_type`s are equality comparable.
template <typename LhsRef, typename RhsRef>
inline constexpr bool ref_can_compare_equal =
  is_pointer_system_convertible_v<LhsRef, RhsRef>
  && ::cuda::std::__is_cpp17_equality_comparable_v<typename LhsRef::value_type, //
                                                   typename RhsRef::value_type>;

// Tagged references can be ordered with `<` only when their systems are
// compatible and the referenced `value_type`s are less-than comparable.
template <typename LhsRef, typename RhsRef>
inline constexpr bool ref_can_compare_less_than =
  is_pointer_system_convertible_v<LhsRef, RhsRef>
  && ::cuda::std::__is_cpp17_less_than_comparable_v<typename LhsRef::value_type, //
                                                    typename RhsRef::value_type>;
} // namespace detail
THRUST_NAMESPACE_END
