Преглед на файлове

merge Libin's upgrade for AVX512

Dhairya Malhotra преди 5 години
родител
ревизия
51f7133d80
променени са 1 файла, в които са добавени 259 реда и са изтрити 4 реда
  1. 259 4
      include/sctl/vec.hpp

+ 259 - 4
include/sctl/vec.hpp

@@ -334,8 +334,6 @@ namespace SCTL_NAMESPACE {
       }
 
       ValueType v[N];
-      friend class Vec<IntegerType,N>;
-      friend class Vec<RealType,N>;
   };
 
   // Other operators
@@ -518,7 +516,7 @@ namespace SCTL_NAMESPACE {
     RealVec x1 = x - x_ * x0; // 2 - cycles
     RealVec x2, x3, x4, x5, x6, x7, x8, x9, x10;
 
-    RealVec e1 = 1 + x1;
+    RealVec e1 = 1.0 + x1;
     if (ORDER >= 2) {
       x2 = x1 * x1;
       e1 += x2 * coeff2;
@@ -695,7 +693,6 @@ namespace SCTL_NAMESPACE {
       friend Vec operator==(Vec lhs, const Vec& rhs) {
         lhs.v = _mm256_cmp_pd(lhs.v, rhs.v, _CMP_EQ_OS);
         return lhs;
-        return lhs;
       }
       friend Vec operator!=(Vec lhs, const Vec& rhs) {
         lhs.v = _mm256_cmp_pd(lhs.v, rhs.v, _CMP_NEQ_OS);
@@ -809,6 +806,264 @@ namespace SCTL_NAMESPACE {
 
 #endif
 
+#ifdef __AVX512F__
+  template <> class alignas(sizeof(double)*8) Vec<double,8> {
+    typedef __m512d VecType;
+    typedef double ValueType;
+    static constexpr Integer N = 8;
+    public:
+
+      typedef typename GetType<DataType::Integer,TypeTraits<ValueType>::Size>::ValueType IntegerType;
+      typedef typename GetType<DataType::Real,TypeTraits<ValueType>::Size>::ValueType RealType;
+      typedef Vec<IntegerType,N> IntegerVec;
+      typedef Vec<RealType,N> RealVec;
+      typedef ValueType ScalarType;
+
+      static constexpr Integer Size() {
+        return N;
+      }
+
+      static Vec Zero() {
+        Vec r;
+        r.v = _mm512_setzero_pd();
+        return r;
+      }
+
+      static Vec Load1(ValueType const* p) {
+        Vec r;
+        // TODO: different from _m256d, could make it faster?
+        // r.v = _mm512_broadcast_f64x4(_mm256_broadcast_sd(p));
+        r.v = _mm512_set1_pd(*p);
+        return r;
+      }
+      static Vec Load(ValueType const* p) {
+        Vec r;
+        r.v = _mm512_loadu_pd(p);
+        return r;
+      }
+      static Vec LoadAligned(ValueType const* p) {
+        Vec r;
+        r.v = _mm512_load_pd(p);
+        return r;
+      }
+
+      Vec() = default;
+
+      Vec(const ValueType& a) {
+        v = _mm512_set1_pd(a);
+      }
+
+      //Vec(const __mmask8& a) {
+      //  v = _mm512_castsi512_pd(_mm512_movm_epi64(a));
+      //}
+
+      void Store(ValueType* p) const {
+        _mm512_storeu_pd(p, v);
+      }
+      void StoreAligned(ValueType* p) const {
+        _mm512_store_pd(p, v);
+      }
+
+      // Bitwise NOT
+      Vec operator~() const {
+        Vec r;
+        static constexpr ScalarType Creal = -1.0;
+        r.v = _mm512_xor_pd(v, _mm512_set1_pd(Creal));
+        return r;
+      }
+
+      // Unary plus and minus
+      Vec operator+() const {
+        return *this;
+      }
+      Vec operator-() const {
+        return Zero() - (*this);
+      }
+
+      // C-style cast
+      //template <class RetValueType> explicit operator Vec<RetValueType,N>() const {
+      //}
+
+      // Arithmetic operators
+      friend Vec operator*(Vec lhs, const Vec& rhs) {
+        lhs.v = _mm512_mul_pd(lhs.v, rhs.v);
+        return lhs;
+      }
+      friend Vec operator+(Vec lhs, const Vec& rhs) {
+        lhs.v = _mm512_add_pd(lhs.v, rhs.v);
+        return lhs;
+      }
+      friend Vec operator-(Vec lhs, const Vec& rhs) {
+        lhs.v = _mm512_sub_pd(lhs.v, rhs.v);
+        return lhs;
+      }
+      friend Vec FMA(Vec a, const Vec& b, const Vec& c) {
+        a.v = _mm512_fmadd_pd(a.v, b.v, c.v);
+        //a.v = _mm512_add_pd(_mm512_mul_pd(a.v, b.v), c.v);
+        return a;
+      }
+
+      // Comparison operators
+      //friend Vec operator< (Vec lhs, const Vec& rhs) {
+      //  lhs.v = _mm512_castsi512_pd(_mm512_movm_epi64(_mm512_cmp_pd_mask(lhs.v, rhs.v, _CMP_LT_OS)));
+      //  return lhs;
+      //}
+      //friend Vec operator<=(Vec lhs, const Vec& rhs) {
+      //  lhs.v = _mm512_castsi512_pd(_mm512_movm_epi64(_mm512_cmp_pd_mask(lhs.v, rhs.v, _CMP_LE_OS)));
+      //  return lhs;
+      //}
+      //friend Vec operator>=(Vec lhs, const Vec& rhs) {
+      //  lhs.v = _mm512_castsi512_pd(_mm512_movm_epi64(_mm512_cmp_pd_mask(lhs.v, rhs.v, _CMP_GE_OS)));
+      //  return lhs;
+      //}
+      //friend Vec operator> (Vec lhs, const Vec& rhs) {
+      //  lhs.v = _mm512_castsi512_pd(_mm512_movm_epi64(_mm512_cmp_pd_mask(lhs.v, rhs.v, _CMP_GT_OS)));
+      //  return lhs;
+      //}
+      //friend Vec operator==(Vec lhs, const Vec& rhs) {
+      //  lhs.v = _mm512_castsi512_pd(_mm512_movm_epi64(_mm512_cmp_pd_mask(lhs.v, rhs.v, _CMP_EQ_OS)));
+      //  return lhs;
+      //}
+      //friend Vec operator!=(Vec lhs, const Vec& rhs) {
+      //  lhs.v = _mm512_castsi512_pd(_mm512_movm_epi64(_mm512_cmp_pd_mask(lhs.v, rhs.v, _CMP_NEQ_OS)));
+      //  return lhs;
+      //}
+
+      friend __mmask8 operator< (Vec lhs, const Vec& rhs) {
+        return _mm512_cmp_pd_mask(lhs.v, rhs.v, _CMP_LT_OS);
+      }
+      friend __mmask8 operator<=(Vec lhs, const Vec& rhs) {
+        return _mm512_cmp_pd_mask(lhs.v, rhs.v, _CMP_LE_OS);
+      }
+      friend __mmask8 operator>=(Vec lhs, const Vec& rhs) {
+        return _mm512_cmp_pd_mask(lhs.v, rhs.v, _CMP_GE_OS);
+      }
+      friend __mmask8 operator> (Vec lhs, const Vec& rhs) {
+        return _mm512_cmp_pd_mask(lhs.v, rhs.v, _CMP_GT_OS);
+      }
+      friend __mmask8 operator==(Vec lhs, const Vec& rhs) {
+        return _mm512_cmp_pd_mask(lhs.v, rhs.v, _CMP_EQ_OS);
+      }
+      friend __mmask8 operator!=(Vec lhs, const Vec& rhs) {
+        return _mm512_cmp_pd_mask(lhs.v, rhs.v, _CMP_NEQ_OS);
+      }
+
+      // Bitwise operators
+      friend Vec operator&(Vec lhs, const Vec& rhs) {
+        lhs.v = _mm512_and_pd(lhs.v, rhs.v);
+        return lhs;
+      }
+      friend Vec operator^(Vec lhs, const Vec& rhs) {
+        lhs.v = _mm512_xor_pd(lhs.v, rhs.v);
+        return lhs;
+      }
+      friend Vec operator|(Vec lhs, const Vec& rhs) {
+        lhs.v = _mm512_or_pd(lhs.v, rhs.v);
+        return lhs;
+      }
+      friend Vec AndNot(Vec lhs, const Vec& rhs) {
+        lhs.v = _mm512_andnot_pd(rhs.v, lhs.v);
+        return lhs;
+      }
+      friend Vec operator&(Vec lhs, const __mmask8& rhs) {
+        lhs.v = _mm512_maskz_mov_pd(rhs, lhs.v);
+        return lhs;
+      }
+
+      // Assignment operators
+      Vec& operator*=(const Vec& rhs) {
+        v = _mm512_mul_pd(v, rhs.v);
+        return *this;
+      }
+      Vec& operator+=(const Vec& rhs) {
+        v = _mm512_add_pd(v, rhs.v);
+        return *this;
+      }
+      Vec& operator-=(const Vec& rhs) {
+        v = _mm512_sub_pd(v, rhs.v);
+        return *this;
+      }
+      Vec& operator&=(const Vec& rhs) {
+        v = _mm512_and_pd(v, rhs.v);
+        return *this;
+      }
+      Vec& operator^=(const Vec& rhs) {
+        v = _mm512_xor_pd(v, rhs.v);
+        return *this;
+      }
+      Vec& operator|=(const Vec& rhs) {
+        v = _mm512_or_pd(v, rhs.v);
+        return *this;
+      }
+      Vec& operator&=(const __mmask8& rhs) {
+        v = _mm512_maskz_mov_pd(rhs, v);
+        return *this;
+      }
+
+      // Other operators
+      friend Vec max(Vec lhs, const Vec& rhs) {
+        lhs.v = _mm512_max_pd(lhs.v, rhs.v);
+        return lhs;
+      }
+      friend Vec min(Vec lhs, const Vec& rhs) {
+        lhs.v = _mm512_min_pd(lhs.v, rhs.v);
+        return lhs;
+      }
+
+      friend std::ostream& operator<<(std::ostream& os, const Vec& in) {
+        union {
+          VecType vec;
+          ValueType val[N];
+        };
+        vec = in.v;
+        for (Integer i = 0; i < N; i++) os << val[i] << ' ';
+        return os;
+      }
+      friend Vec approx_rsqrt(const Vec& x) {
+        Vec r;
+        r.v = _mm512_cvtps_pd(_mm256_rsqrt_ps(_mm512_cvtpd_ps(x.v)));
+        return r;
+      }
+
+      template <class Vec1, class Vec2> friend Vec1 reinterpret(const Vec2& x);
+      template <class Vec> friend Vec RoundReal2Real(const Vec& x);
+      template <class Vec> friend void sincos_intrin(Vec& sinx, Vec& cosx, const Vec& x);
+      template <class Vec> friend void exp_intrin(Vec& expx, const Vec& x);
+
+    private:
+
+      VecType v;
+  };
+
+  template <> inline Vec<int64_t,8> reinterpret<Vec<int64_t,8>,Vec<double,8>>(const Vec<double,8>& x){
+    union {
+      Vec<int64_t,8> r;
+      __m512i y;
+    };
+    y = _mm512_castpd_si512(x.v);
+    return r;
+  }
+
+  template <> inline Vec<double,8> RoundReal2Real(const Vec<double,8>& x) {
+    Vec<double,8> r;
+    // TODO: need double check
+    r.v = _mm512_roundscale_pd(x.v,_MM_FROUND_TO_NEAREST_INT |_MM_FROUND_NO_EXC);
+    return r;
+  }
+
+  #ifdef SCTL_HAVE_SVML
+  template <> inline void sincos_intrin(Vec<double,8>& sinx, Vec<double,8>& cosx, const Vec<double,8>& x) {
+    sinx.v = _mm512_sin_pd(x.v);
+    cosx.v = _mm512_cos_pd(x.v);
+  }
+
+  template <> inline void exp_intrin(Vec<double,8>& expx, const Vec<double,8>& x) {
+    expx.v = _mm512_exp_pd(x.v);
+  }
+  #endif
+
+#endif
+
 }
 
 #endif  //_SCTL_VEC_WRAPPER_HPP_