vec.hpp 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310
  1. #ifndef _SCTL_VEC_WRAPPER_HPP_
  2. #define _SCTL_VEC_WRAPPER_HPP_
  #include <sctl/common.hpp>
  #include SCTL_INCLUDE(intrin-wrapper.hpp)
  #include <cassert>
  #include <cstdint>
  #include <cstring>
  #include <iostream>
  #include <ostream>
  8. namespace SCTL_NAMESPACE {
  9. #if defined(__AVX512__) || defined(__AVX512F__)
  10. static_assert(SCTL_ALIGN_BYTES >= 64, "Insufficient memory alignment for SIMD vector types");
  11. template <class ScalarType> constexpr Integer DefaultVecLen() { return 64/sizeof(ScalarType); }
  12. #elif defined(__AVX__)
  13. static_assert(SCTL_ALIGN_BYTES >= 32, "Insufficient memory alignment for SIMD vector types");
  14. template <class ScalarType> constexpr Integer DefaultVecLen() { return 32/sizeof(ScalarType); }
  15. #elif defined(__SSE4_2__)
  16. static_assert(SCTL_ALIGN_BYTES >= 16, "Insufficient memory alignment for SIMD vector types");
  17. template <class ScalarType> constexpr Integer DefaultVecLen() { return 16/sizeof(ScalarType); }
  18. #else
  19. static_assert(SCTL_ALIGN_BYTES >= 8, "Insufficient memory alignment for SIMD vector types");
  20. template <class ScalarType> constexpr Integer DefaultVecLen() { return 1; }
  21. #endif
  22. template <class ValueType, Integer N = DefaultVecLen<ValueType>()> class alignas(sizeof(ValueType) * N) Vec {
  23. public:
  24. using ScalarType = ValueType;
  25. using VData = VecData<ScalarType,N>;
  26. using MaskType = Mask<VData>;
  27. static constexpr Integer Size() {
  28. return N;
  29. }
  30. static inline Vec Zero() {
  31. Vec r;
  32. r.v = zero_intrin<VData>();
  33. return r;
  34. }
  35. static inline Vec Load1(ScalarType const* p) {
  36. Vec r;
  37. r.v = load1_intrin<VData>(p);
  38. return r;
  39. }
  40. static inline Vec Load(ScalarType const* p) {
  41. Vec r;
  42. r.v = loadu_intrin<VData>(p);
  43. return r;
  44. }
  45. static inline Vec LoadAligned(ScalarType const* p) {
  46. Vec r;
  47. r.v = load_intrin<VData>(p);
  48. return r;
  49. }
  50. Vec() = default;
  51. Vec(const Vec&) = default;
  52. Vec& operator=(const Vec&) = default;
  53. ~Vec() = default;
  54. inline Vec(const VData& v_) : v(v_) {}
  55. inline Vec(const ScalarType& a) : Vec(set1_intrin<VData>(a)) {}
  56. template <class T,class ...T1> inline Vec(T x, T1... args) : Vec(InitVec<T1...>::template apply<ScalarType>((ScalarType)x,args...)) {}
  57. inline void Store(ScalarType* p) const {
  58. storeu_intrin(p,v);
  59. }
  60. inline void StoreAligned(ScalarType* p) const {
  61. store_intrin(p,v);
  62. }
  63. inline void StreamStoreAligned(ScalarType* p) const {
  64. stream_store_intrin(p,v);
  65. }
  66. // Conversion operators
  67. friend inline Mask<VData> convert2mask(const Vec& a) {
  68. return convert_vec2mask_intrin(a.v);
  69. }
  70. friend inline Vec RoundReal2Real(const Vec& x) {
  71. return round_real2real_intrin(x.v);
  72. }
  73. template <class IntVec, class RealVec> friend IntVec RoundReal2Int(const RealVec& x);
  74. template <class RealVec, class IntVec> friend RealVec ConvertInt2Real(const IntVec& x);
  75. // Element access
  76. inline ScalarType operator[](Integer i) const {
  77. return extract_intrin(v,i);
  78. }
  79. inline void insert(Integer i, ScalarType value) {
  80. insert_intrin(v,i,value);
  81. }
  82. // Arithmetic operators
  83. inline Vec operator+() const {
  84. return *this;
  85. }
  86. inline Vec operator-() const {
  87. return unary_minus_intrin(v); // Zero() - (*this);
  88. }
  89. friend inline Vec operator*(const Vec& a, const Vec& b) {
  90. return mul_intrin(a.v, b.v);
  91. }
  92. friend inline Vec operator/(const Vec& a, const Vec& b) {
  93. return div_intrin(a.v, b.v);
  94. }
  95. friend inline Vec operator+(const Vec& a, const Vec& b) {
  96. return add_intrin(a.v, b.v);
  97. }
  98. friend inline Vec operator-(const Vec& a, const Vec& b) {
  99. return sub_intrin(a.v, b.v);
  100. }
  101. friend inline Vec FMA(const Vec& a, const Vec& b, const Vec& c) {
  102. return fma_intrin(a.v, b.v, c.v);
  103. }
  104. // Comparison operators
  105. friend inline Mask<VData> operator< (const Vec& a, const Vec& b) {
  106. return comp_intrin<ComparisonType::lt>(a.v, b.v);
  107. }
  108. friend inline Mask<VData> operator<=(const Vec& a, const Vec& b) {
  109. return comp_intrin<ComparisonType::le>(a.v, b.v);
  110. }
  111. friend inline Mask<VData> operator>=(const Vec& a, const Vec& b) {
  112. return comp_intrin<ComparisonType::ge>(a.v, b.v);
  113. }
  114. friend inline Mask<VData> operator> (const Vec& a, const Vec& b) {
  115. return comp_intrin<ComparisonType::gt>(a.v, b.v);
  116. }
  117. friend inline Mask<VData> operator==(const Vec& a, const Vec& b) {
  118. return comp_intrin<ComparisonType::eq>(a.v, b.v);
  119. }
  120. friend inline Mask<VData> operator!=(const Vec& a, const Vec& b) {
  121. return comp_intrin<ComparisonType::ne>(a.v, b.v);
  122. }
  123. friend inline Vec select(const Mask<VData>& m, const Vec& a, const Vec& b) {
  124. return select_intrin(m, a.v, b.v);
  125. }
  126. // Bitwise operators
  127. inline Vec operator~() const {
  128. return not_intrin(v);
  129. }
  130. friend inline Vec operator&(const Vec& a, const Vec& b) {
  131. return and_intrin(a.v, b.v);
  132. }
  133. friend inline Vec operator^(const Vec& a, const Vec& b) {
  134. return xor_intrin(a.v, b.v);
  135. }
  136. friend inline Vec operator|(const Vec& a, const Vec& b) {
  137. return or_intrin(a.v, b.v);
  138. }
  139. friend inline Vec AndNot(const Vec& a, const Vec& b) { // return a & ~b
  140. return andnot_intrin(a.v, b.v);
  141. }
  142. // Bitshift
  143. friend inline Vec operator<<(const Vec& lhs, const Integer& rhs) {
  144. return bitshiftleft_intrin(lhs.v, rhs);
  145. }
  146. friend inline Vec operator>>(const Vec& lhs, const Integer& rhs) {
  147. return bitshiftright_intrin(lhs.v, rhs);
  148. }
  149. // Assignment operators
  150. inline Vec& operator=(const ScalarType& a) {
  151. v = set1_intrin<VData>(a);
  152. return *this;
  153. }
  154. inline Vec& operator*=(const Vec& rhs) {
  155. v = mul_intrin(v, rhs.v);
  156. return *this;
  157. }
  158. inline Vec& operator/=(const Vec& rhs) {
  159. v = div_intrin(v, rhs.v);
  160. return *this;
  161. }
  162. inline Vec& operator+=(const Vec& rhs) {
  163. v = add_intrin(v, rhs.v);
  164. return *this;
  165. }
  166. inline Vec& operator-=(const Vec& rhs) {
  167. v = sub_intrin(v, rhs.v);
  168. return *this;
  169. }
  170. inline Vec& operator&=(const Vec& rhs) {
  171. v = and_intrin(v, rhs.v);
  172. return *this;
  173. }
  174. inline Vec& operator^=(const Vec& rhs) {
  175. v = xor_intrin(v, rhs.v);
  176. return *this;
  177. }
  178. inline Vec& operator|=(const Vec& rhs) {
  179. v = or_intrin(v, rhs.v);
  180. return *this;
  181. }
  182. // Other operators
  183. friend inline Vec max(const Vec& lhs, const Vec& rhs) {
  184. return max_intrin(lhs.v, rhs.v);
  185. }
  186. friend inline Vec min(const Vec& lhs, const Vec& rhs) {
  187. return min_intrin(lhs.v, rhs.v);
  188. }
  189. // Special functions
  190. template <Integer digits, class RealVec> friend RealVec approx_rsqrt(const RealVec& x);
  191. template <Integer digits, class RealVec> friend RealVec approx_rsqrt(const RealVec& x, const typename RealVec::MaskType& m);
  192. friend inline void sincos(Vec& sinx, Vec& cosx, const Vec& x) {
  193. sincos_intrin(sinx.v, cosx.v, x.v);
  194. }
  195. template <Integer digits, class RealVec> friend void approx_sincos(RealVec& sinx, RealVec& cosx, const RealVec& x);
  196. friend inline Vec exp(const Vec& x) {
  197. return exp_intrin(x.v);
  198. }
  199. template <Integer digits, class RealVec> friend RealVec approx_exp(const RealVec& x);
  200. //template <class Vec1, class Vec2> friend Vec1 reinterpret(const Vec2& x);
  201. //template <class Vec> friend Vec RoundReal2Real(const Vec& x);
  202. //template <class Vec> friend void exp_intrin(Vec& expx, const Vec& x);
  203. // Print
  204. friend inline std::ostream& operator<<(std::ostream& os, const Vec& in) {
  205. for (Integer i = 0; i < Size(); i++) os << in[i] << ' ';
  206. return os;
  207. }
  208. inline void set(const VData& v_) { v = v_; }
  209. inline const VData& get() const { return v; }
  210. private:
  211. template <class T, class... T2> struct InitVec {
  212. template <class... T1> static inline VData apply(T1... start, T x, T2... rest) {
  213. return InitVec<T2...>::template apply<ScalarType, T1...>(start..., (ScalarType)x, rest...);
  214. }
  215. };
  216. template <class T> struct InitVec<T> {
  217. template <class... T1> static inline VData apply(T1... start, T x) {
  218. return set_intrin<VData>(start..., (ScalarType)x);
  219. }
  220. };
  221. VData v;
  222. };
  223. // Conversion operators
  224. template <class RealVec, class IntVec> inline RealVec ConvertInt2Real(const IntVec& x) {
  225. return convert_int2real_intrin<typename RealVec::VData>(x.v);
  226. }
  227. template <class IntVec, class RealVec> inline IntVec RoundReal2Int(const RealVec& x) {
  228. return round_real2int_intrin<typename IntVec::VData>(x.v);
  229. }
  230. template <class MaskType> inline Vec<typename MaskType::ScalarType,MaskType::Size> convert2vec(const MaskType& a) {
  231. return convert_mask2vec_intrin(a);
  232. }
  233. // Special functions
  234. template <Integer digits, class RealVec> inline RealVec approx_rsqrt(const RealVec& x) {
  235. static constexpr Integer digits_ = (digits==-1 ? (Integer)(TypeTraits<typename RealVec::ScalarType>::SigBits*0.3010299957) : digits);
  236. return rsqrt_approx_intrin<digits_, typename RealVec::VData>::eval(x.v);
  237. }
  238. template <Integer digits, class RealVec> inline RealVec approx_rsqrt(const RealVec& x, const typename RealVec::MaskType& m) {
  239. static constexpr Integer digits_ = (digits==-1 ? (Integer)(TypeTraits<typename RealVec::ScalarType>::SigBits*0.3010299957) : digits);
  240. return rsqrt_approx_intrin<digits_, typename RealVec::VData>::eval(x.v, m);
  241. }
  242. template <Integer digits, class RealVec> inline RealVec approx_sqrt(const RealVec& x) {
  243. return x*approx_rsqrt<digits>(x);
  244. }
  245. template <Integer digits, class RealVec> inline RealVec approx_sqrt(const RealVec& x, const typename RealVec::MaskType& m) {
  246. return x*approx_rsqrt<digits>(x, m);
  247. }
  248. template <Integer digits, class RealVec> inline void approx_sincos(RealVec& sinx, RealVec& cosx, const RealVec& x) {
  249. constexpr Integer ORDER = (digits>1?digits>9?digits>14?digits>17?digits-1:digits:digits+1:digits+2:1);
  250. if (digits == -1 || ORDER > 20) sincos(sinx, cosx, x);
  251. else approx_sincos_intrin<ORDER>(sinx.v, cosx.v, x.v);
  252. }
  253. template <Integer digits, class RealVec> inline RealVec approx_exp(const RealVec& x) {
  254. constexpr Integer ORDER = digits;
  255. if (digits == -1 || ORDER > 13) return exp(x);
  256. else return approx_exp_intrin<ORDER>(x.v);
  257. }
  258. // Other operators
  // Print the object representation of x to std::cout as a run of '0'/'1'
  // characters followed by a newline. Bytes are visited in memory order and,
  // within each byte, bits are written least-significant first.
  template <class ValueType> inline void printb(const ValueType& x) { // print binary
    // Copy the bytes out with memcpy: reading a union member other than the one
    // last written is undefined behaviour in C++.
    unsigned char bytes[sizeof(ValueType)];
    std::memcpy(bytes, &x, sizeof(ValueType));

    for (const unsigned char byte : bytes) {
      for (int bit = 0; bit < 8; bit++) {
        std::cout << ((byte & (1U << bit)) ? '1' : '0');
      }
    }
    std::cout << '\n';
  }
  272. }
  273. #endif //_SCTL_VEC_WRAPPER_HPP_