Program Listing for File complex.h

Return to documentation for file (thrust/complex.h)

/*
 *  Copyright 2008-2019 NVIDIA Corporation
 *  Copyright 2013 Filipe RNC Maia
 *  Modifications Copyright© 2019 Advanced Micro Devices, Inc. All rights reserved.
 *
 *  Licensed under the Apache License, Version 2.0 (the "License");
 *  you may not use this file except in compliance with the License.
 *  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 *  Unless required by applicable law or agreed to in writing, software
 *  distributed under the License is distributed on an "AS IS" BASIS,
 *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  See the License for the specific language governing permissions and
 *  limitations under the License.
 */

#pragma once

#include <thrust/detail/config.h>

#include <cmath>
#include <complex>
#include <sstream>
#include <thrust/detail/type_traits.h>

#if THRUST_CPP_DIALECT < 2011
// Pre-C++11 (no decltype): read the components through the std::complex
// accessors and leave the std::complex interop functions host-only.
#  define THRUST_STD_COMPLEX_DEVICE
#  define THRUST_STD_COMPLEX_REAL(z) (z).real()
#  define THRUST_STD_COMPLEX_IMAG(z) (z).imag()
#else
// C++11 and later: read the components through the `value_type (&)[2]` view
// the standard guarantees for std::complex. This presumably lets device code
// access them without calling std::complex member functions. The std::complex
// interop functions are additionally marked __device__.
// All three macros are #undef'd at the end of this header.
#  define THRUST_STD_COMPLEX_DEVICE __device__
#  define THRUST_STD_COMPLEX_REAL(z)                                        \
    reinterpret_cast<const typename thrust::detail::remove_reference<       \
      decltype(z)>::type::value_type (&)[2]>(z)[0]
#  define THRUST_STD_COMPLEX_IMAG(z)                                        \
    reinterpret_cast<const typename thrust::detail::remove_reference<       \
      decltype(z)>::type::value_type (&)[2]>(z)[1]
#endif

namespace thrust
{

/*
 *  Calls to the standard math library from inside the thrust namespace
 *  with real arguments require explicit scope (e.g. std::sqrt). Otherwise
 *  unqualified name lookup finds the complex overload declared in this
 *  namespace first. Matching the real argument against that template then
 *  fails, and lookup does not continue into enclosing scopes.
 */


/*! \p complex stores a complex number as a real and an imaginary component
 *  of type \p T. It is usable from both host and device code, and can be
 *  constructed from, assigned from, and converted to <tt>std::complex</tt>.
 *
 *  Members that are only declared here are defined out of line, presumably
 *  in the implementation header included at the end of this file.
 *
 *  \tparam T Type of the real and imaginary components.
 */
template <typename T>
struct complex
{
public:

  /*! Type of the real and imaginary components. */
  typedef T value_type;



  /* --- Constructors --- */

  /*! Construct from the real value \p re.
   *  NOTE(review): the imaginary part is presumably set to T(); confirm
   *  against the out-of-line definition.
   */
  __host__ __device__
  complex(const T& re);

  /*! Construct from real part \p re and imaginary part \p im. */
  __host__ __device__
  complex(const T& re, const T& im);

#if THRUST_CPP_DIALECT >= 2011

  /*! Defaulted default constructor. The components are default-initialized,
   *  which leaves them indeterminate for arithmetic \p T.
   */
  __host__ __device__
  complex() = default;

  /*! Defaulted (memberwise) copy constructor. */
  __host__ __device__
  complex(const complex<T>& z) = default;
#else

  /*! Default constructor (defined out of line before C++11). */
  __host__ __device__
  complex();

  /*! Copy constructor (defined out of line before C++11). */
  __host__ __device__
  complex(const complex<T>& z);
#endif

  /*! Converting constructor from a \p complex whose value type is \p U. */
  template <typename U>
  __host__ __device__
  complex(const complex<U>& z);

  /*! Construct from a <tt>std::complex</tt> with the same value type.
   *  Device-callable only when THRUST_STD_COMPLEX_DEVICE expands to
   *  __device__ (the C++11-and-later configuration).
   */
  __host__ THRUST_STD_COMPLEX_DEVICE
  complex(const std::complex<T>& z);

  /*! Construct from a <tt>std::complex</tt> whose value type is \p U. */
  template <typename U>
  __host__ THRUST_STD_COMPLEX_DEVICE
  complex(const std::complex<U>& z);



  /* --- Assignment Operators --- */

  /*! Assign from the real value \p re.
   *  NOTE(review): the imaginary part is presumably reset to T(); confirm
   *  against the out-of-line definition.
   */
  __host__ __device__
  complex& operator=(const T& re);

#if THRUST_CPP_DIALECT >= 2011

  /*! Defaulted (memberwise) copy assignment. */
  __host__ __device__
  complex& operator=(const complex<T>& z) = default;
#else

  /*! Copy assignment (defined out of line before C++11). */
  __host__ __device__
  complex& operator=(const complex<T>& z);
#endif

  /*! Assign from a \p complex whose value type is \p U. */
  template <typename U>
  __host__ __device__
  complex& operator=(const complex<U>& z);

  /*! Assign from a <tt>std::complex</tt> with the same value type. */
  __host__ THRUST_STD_COMPLEX_DEVICE
  complex& operator=(const std::complex<T>& z);

  /*! Assign from a <tt>std::complex</tt> whose value type is \p U. */
  template <typename U>
  __host__ THRUST_STD_COMPLEX_DEVICE
  complex& operator=(const std::complex<U>& z);


  /* --- Compound Assignment Operators ---
   * Each operator has one overload taking another \p complex and one
   * taking a value of type \p U. Each returns a reference to *this.
   */

  template <typename U>
  __host__ __device__
  complex<T>& operator+=(const complex<U>& z);

  template <typename U>
  __host__ __device__
  complex<T>& operator-=(const complex<U>& z);

  template <typename U>
  __host__ __device__
  complex<T>& operator*=(const complex<U>& z);

  template <typename U>
  __host__ __device__
  complex<T>& operator/=(const complex<U>& z);

  template <typename U>
  __host__ __device__
  complex<T>& operator+=(const U& z);

  template <typename U>
  __host__ __device__
  complex<T>& operator-=(const U& z);

  template <typename U>
  __host__ __device__
  complex<T>& operator*=(const U& z);

  template <typename U>
  __host__ __device__
  complex<T>& operator/=(const U& z);



  /* --- Getter functions ---
   * The volatile overloads allow the components to be read through a
   * volatile object (useful, for example, in certain reduction
   * optimizations).
   */

  /*! Real part (volatile overload). */
  __host__ __device__
  T real() const volatile { return data.x; }

  /*! Imaginary part (volatile overload). */
  __host__ __device__
  T imag() const volatile { return data.y; }

  /*! Real part. */
  __host__ __device__
  T real() const { return data.x; }

  /*! Imaginary part. */
  __host__ __device__
  T imag() const { return data.y; }



  /* --- Setter functions ---
   * The volatile overloads allow the components to be written through a
   * volatile object (useful, for example, in certain reduction
   * optimizations).
   */

  /*! Set the real part to \p re (volatile overload). */
  __host__ __device__
  void real(T re) volatile { data.x = re; }

  /*! Set the imaginary part to \p im (volatile overload). */
  __host__ __device__
  void imag(T im) volatile { data.y = im; }

  /*! Set the real part to \p re. */
  __host__ __device__
  void real(T re) { data.x = re; }

  /*! Set the imaginary part to \p im. */
  __host__ __device__
  void imag(T im) { data.y = im; }



  /* --- Casting functions --- */

  /*! Convert to a <tt>std::complex</tt> with the same value type.
   *  This uses the same execution-space qualifier as the other std::complex
   *  interop members because it constructs a std::complex. That is
   *  presumably only device-callable in the configuration where
   *  THRUST_STD_COMPLEX_DEVICE expands to __device__.
   */
  __host__ THRUST_STD_COMPLEX_DEVICE
  operator std::complex<T>() const { return std::complex<T>(real(), imag()); }

private:
  // Fallback layout: two consecutive T members, named like the CUDA vector
  // types so that every layout is accessed as data.x / data.y.
  struct generic_storage_type { T x; T y; };
#if THRUST_DEVICE_COMPILER == THRUST_DEVICE_COMPILER_NVCC || THRUST_DEVICE_COMPILER == THRUST_DEVICE_COMPILER_HCC
  // With NVCC/HCC, float and double components (const-qualified or not) are
  // stored in the built-in float2/double2 vector types, presumably for their
  // stronger alignment. Every other T uses generic_storage_type.
  typedef typename detail::conditional<
    detail::is_same<T, float>::value, float2,
    typename detail::conditional<
      detail::is_same<T, float const>::value, float2 const,
      typename detail::conditional<
        detail::is_same<T, double>::value, double2,
        typename detail::conditional<
          detail::is_same<T, double const>::value, double2 const,
          generic_storage_type
        >::type
      >::type
    >::type
  >::type storage_type;
#else
  typedef generic_storage_type storage_type;
#endif

  // Real part lives in data.x, imaginary part in data.y.
  storage_type data;
};


/* --- General Functions ---
 * Counterparts of the <complex> free functions of the same names. The
 * definitions presumably come from the implementation header included at
 * the end of this file.
 */

/*! Counterpart of std::abs (magnitude of \p z). */
template <typename T> __host__ __device__
T abs(const complex<T>& z);

/*! Counterpart of std::arg (phase angle of \p z). */
template <typename T> __host__ __device__
T arg(const complex<T>& z);

/*! Counterpart of std::norm (squared magnitude of \p z). */
template <typename T> __host__ __device__
T norm(const complex<T>& z);

/*! Counterpart of std::conj (complex conjugate of \p z). */
template <typename T> __host__ __device__
complex<T> conj(const complex<T>& z);

/*! Counterpart of std::polar: build a complex from magnitude \p m and angle
 *  \p theta. \p theta defaults to a value-initialized T1. The result's value
 *  type is the promoted numerical type of T0 and T1.
 */
template <typename T0, typename T1> __host__ __device__
complex<typename detail::promoted_numerical_type<T0, T1>::type>
polar(const T0& m, const T1& theta = T1());

/*! Counterpart of std::proj.
 *  NOTE(review): unlike the other functions here, this takes `const T&`
 *  rather than `const complex<T>&`. Confirm against the definition in the
 *  included implementation header.
 */
template <typename T> __host__ __device__
complex<T> proj(const T& z);



/* --- Binary Arithmetic operators ---
 * Each of +, -, * and / has three overloads: complex op complex,
 * complex op scalar, and scalar op complex. The result's value type is
 * the promoted numerical type of the two operand types.
 */

// complex op complex
template <typename T0, typename T1> __host__ __device__
complex<typename detail::promoted_numerical_type<T0, T1>::type>
operator+(const complex<T0>& x, const complex<T1>& y);

template <typename T0, typename T1> __host__ __device__
complex<typename detail::promoted_numerical_type<T0, T1>::type>
operator-(const complex<T0>& x, const complex<T1>& y);

template <typename T0, typename T1> __host__ __device__
complex<typename detail::promoted_numerical_type<T0, T1>::type>
operator*(const complex<T0>& x, const complex<T1>& y);

template <typename T0, typename T1> __host__ __device__
complex<typename detail::promoted_numerical_type<T0, T1>::type>
operator/(const complex<T0>& x, const complex<T1>& y);

// complex op scalar
template <typename T0, typename T1> __host__ __device__
complex<typename detail::promoted_numerical_type<T0, T1>::type>
operator+(const complex<T0>& x, const T1& y);

template <typename T0, typename T1> __host__ __device__
complex<typename detail::promoted_numerical_type<T0, T1>::type>
operator-(const complex<T0>& x, const T1& y);

template <typename T0, typename T1> __host__ __device__
complex<typename detail::promoted_numerical_type<T0, T1>::type>
operator*(const complex<T0>& x, const T1& y);

template <typename T0, typename T1> __host__ __device__
complex<typename detail::promoted_numerical_type<T0, T1>::type>
operator/(const complex<T0>& x, const T1& y);

// scalar op complex
template <typename T0, typename T1> __host__ __device__
complex<typename detail::promoted_numerical_type<T0, T1>::type>
operator+(const T0& x, const complex<T1>& y);

template <typename T0, typename T1> __host__ __device__
complex<typename detail::promoted_numerical_type<T0, T1>::type>
operator-(const T0& x, const complex<T1>& y);

template <typename T0, typename T1> __host__ __device__
complex<typename detail::promoted_numerical_type<T0, T1>::type>
operator*(const T0& x, const complex<T1>& y);

template <typename T0, typename T1> __host__ __device__
complex<typename detail::promoted_numerical_type<T0, T1>::type>
operator/(const T0& x, const complex<T1>& y);



/* --- Unary Arithmetic operators --- */

/*! Unary plus on \p z. */
template <typename T> __host__ __device__
complex<T> operator+(const complex<T>& z);

/*! Unary minus (negation) of \p z. */
template <typename T> __host__ __device__
complex<T> operator-(const complex<T>& z);



/* --- Exponential Functions ---
 * Counterparts of std::exp, std::log and std::log10.
 */

template <typename T> __host__ __device__
complex<T> exp(const complex<T>& z);

template <typename T> __host__ __device__
complex<T> log(const complex<T>& z);

template <typename T> __host__ __device__
complex<T> log10(const complex<T>& z);



/* --- Power Functions ---
 * pow has complex^complex, complex^scalar and scalar^complex overloads.
 * The result's value type is the promoted numerical type of the operand
 * types. sqrt is the counterpart of std::sqrt.
 */

template <typename T0, typename T1> __host__ __device__
complex<typename detail::promoted_numerical_type<T0, T1>::type>
pow(const complex<T0>& x, const complex<T1>& y);

template <typename T0, typename T1> __host__ __device__
complex<typename detail::promoted_numerical_type<T0, T1>::type>
pow(const complex<T0>& x, const T1& y);

template <typename T0, typename T1> __host__ __device__
complex<typename detail::promoted_numerical_type<T0, T1>::type>
pow(const T0& x, const complex<T1>& y);

template <typename T> __host__ __device__
complex<T> sqrt(const complex<T>& z);


/* --- Trigonometric Functions ---
 * Counterparts of std::cos, std::sin and std::tan.
 */

template <typename T> __host__ __device__
complex<T> cos(const complex<T>& z);

template <typename T> __host__ __device__
complex<T> sin(const complex<T>& z);

template <typename T> __host__ __device__
complex<T> tan(const complex<T>& z);



/* --- Hyperbolic Functions ---
 * Counterparts of std::cosh, std::sinh and std::tanh.
 */

template <typename T> __host__ __device__
complex<T> cosh(const complex<T>& z);

template <typename T> __host__ __device__
complex<T> sinh(const complex<T>& z);

template <typename T> __host__ __device__
complex<T> tanh(const complex<T>& z);



/* --- Inverse Trigonometric Functions ---
 * Counterparts of std::acos, std::asin and std::atan.
 */

template <typename T> __host__ __device__
complex<T> acos(const complex<T>& z);

template <typename T> __host__ __device__
complex<T> asin(const complex<T>& z);

template <typename T> __host__ __device__
complex<T> atan(const complex<T>& z);



/* --- Inverse Hyperbolic Functions ---
 * Counterparts of std::acosh, std::asinh and std::atanh.
 */

template <typename T> __host__ __device__
complex<T> acosh(const complex<T>& z);

template <typename T> __host__ __device__
complex<T> asinh(const complex<T>& z);

template <typename T> __host__ __device__
complex<T> atanh(const complex<T>& z);



/* --- Stream Operators ---
 * Host-only (no __device__ qualifier). Counterparts of the std::complex
 * stream insertion and extraction operators.
 */

/*! Write \p z to the output stream \p os and return \p os. */
template <typename T, typename CharT, typename Traits>
std::basic_ostream<CharT, Traits>&
operator<<(std::basic_ostream<CharT, Traits>& os, const complex<T>& z);

/*! Read a complex value from the input stream \p is into \p z and return \p is. */
template <typename T, typename CharT, typename Traits>
std::basic_istream<CharT, Traits>&
operator>>(std::basic_istream<CharT, Traits>& is, complex<T>& z);



/* --- Equality Operators ---
 * == and != are declared in pairs for each operand combination:
 * thrust::complex with thrust::complex, with std::complex (either side),
 * and with a scalar (either side). The std::complex overloads are
 * device-callable only when THRUST_STD_COMPLEX_DEVICE expands to __device__.
 */

// thrust::complex vs thrust::complex
template <typename T0, typename T1> __host__ __device__
bool operator==(const complex<T0>& x, const complex<T1>& y);

template <typename T0, typename T1> __host__ __device__
bool operator!=(const complex<T0>& x, const complex<T1>& y);

// thrust::complex vs std::complex
template <typename T0, typename T1> __host__ THRUST_STD_COMPLEX_DEVICE
bool operator==(const complex<T0>& x, const std::complex<T1>& y);

template <typename T0, typename T1> __host__ THRUST_STD_COMPLEX_DEVICE
bool operator!=(const complex<T0>& x, const std::complex<T1>& y);

// std::complex vs thrust::complex
template <typename T0, typename T1> __host__ THRUST_STD_COMPLEX_DEVICE
bool operator==(const std::complex<T0>& x, const complex<T1>& y);

template <typename T0, typename T1> __host__ THRUST_STD_COMPLEX_DEVICE
bool operator!=(const std::complex<T0>& x, const complex<T1>& y);

// scalar vs thrust::complex
template <typename T0, typename T1> __host__ __device__
bool operator==(const T0& x, const complex<T1>& y);

template <typename T0, typename T1> __host__ __device__
bool operator!=(const T0& x, const complex<T1>& y);

// thrust::complex vs scalar
template <typename T0, typename T1> __host__ __device__
bool operator==(const complex<T0>& x, const T1& y);

template <typename T0, typename T1> __host__ __device__
bool operator!=(const complex<T0>& x, const T1& y);

} // end namespace thrust

#include <thrust/detail/complex/complex.inl>

#undef THRUST_STD_COMPLEX_REAL
#undef THRUST_STD_COMPLEX_IMAG
#undef THRUST_STD_COMPLEX_DEVICE