intrin_wrapper.hpp 13 KB


  1. #ifndef _SCTL_INTRIN_WRAPPER_HPP_
  2. #define _SCTL_INTRIN_WRAPPER_HPP_
  3. #include SCTL_INCLUDE(math_utils.hpp)
  4. #include SCTL_INCLUDE(common.hpp)
  5. #include <cstdint>
  6. #ifdef __SSE__
  7. #include <xmmintrin.h>
  8. #endif
  9. #ifdef __SSE2__
  10. #include <emmintrin.h>
  11. #endif
  12. #ifdef __SSE3__
  13. #include <pmmintrin.h>
  14. #endif
  15. #ifdef __AVX__
  16. #include <immintrin.h>
  17. #endif
  18. #if defined(__MIC__)
  19. #include <immintrin.h>
  20. #endif
  21. // TODO: Check alignment when SCTL_MEMDEBUG is defined
  22. // TODO: Replace pointers with iterators
  23. namespace SCTL_NAMESPACE {
  // Scalar fallbacks: for a plain arithmetic type T the "vector" has exactly
  // one lane, so each helper below works on a single element.

  // Returns T constructed from the integer literal 0.
  template <class T> inline T zero_intrin() {
    return T(0);
  }

  // Returns the scalar `a` converted to T.
  template <class T, class Real> inline T set_intrin(const Real& a) {
    return a;
  }

  // Reads the element stored at `a` and returns it converted to T.
  template <class T, class Real> inline T load_intrin(Real const* a) {
    return *a;
  }

  // Identical to load_intrin in the scalar case: there is only one lane to fill.
  template <class T, class Real> inline T bcast_intrin(Real const* a) {
    return *a;
  }

  // Writes `b` to the element that `a` points to.
  template <class T, class Real> inline void store_intrin(Real* a, const T& b) {
    *a = b;
  }
  // Product a * b of the two operands.
  template <class T> inline T mul_intrin(const T& a, const T& b) {
    return a * b;
  }

  // Sum a + b of the two operands.
  template <class T> inline T add_intrin(const T& a, const T& b) {
    return a + b;
  }

  // Difference a - b of the two operands.
  template <class T> inline T sub_intrin(const T& a, const T& b) {
    return a - b;
  }
  // Scalar "less than" mask, mirroring the SIMD compare instructions:
  // when a < b every byte of the result is set to 0xFF, otherwise the result is
  // T initialised from 0.  The mask is meant to be combined with and_intrin.
  template <class T> inline T cmplt_intrin(const T& a, const T& b) {
    T mask = 0;
    if (!(a < b)) return mask;

    // Fill the object representation of the result with one-bits.
    uint8_t* cursor = reinterpret_cast<uint8_t*>(&mask);
    uint8_t* const stop = cursor + sizeof(T);
    while (cursor != stop) {
      *cursor = static_cast<uint8_t>(0xFF);
      ++cursor;
    }
    return mask;
  }
  // Bitwise AND of the object representations of `a` and `b`, computed byte by
  // byte so that it also works for floating-point T (e.g. to apply a mask
  // produced by cmplt_intrin).
  template <class T> inline T and_intrin(const T& a, const T& b) {
    T out = 0;
    const uint8_t* lhs = reinterpret_cast<const uint8_t*>(&a);
    const uint8_t* rhs = reinterpret_cast<const uint8_t*>(&b);
    uint8_t* dst = reinterpret_cast<uint8_t*>(&out);
    uint8_t* const dst_end = dst + sizeof(T);
    while (dst != dst_end) {
      *dst = *lhs & *rhs;
      ++dst;
      ++lhs;
      ++rhs;
    }
    return out;
  }
  47. template <class T> inline T rsqrt_approx_intrin(const T& r2) {
  48. if (r2 != 0) return 1.0 / sqrt<T>(r2);
  49. return 0;
  50. }
  // One unnormalised Newton step for the reciprocal square root:
  //   rinv <- rinv * (nwtn_const - r2 * rinv^2)
  // The usual factor 0.5 is deliberately left out; callers account for the
  // resulting scale through their choice of nwtn_const and a later rescale.
  //   rinv       in/out: current estimate, overwritten with the updated one
  //   r2         the value whose reciprocal square root is sought
  //   nwtn_const the constant of the iteration (3 for the first step)
  template <class T, class Real> inline void rsqrt_newton_intrin(T& rinv, const T& r2, const Real& nwtn_const) {
    rinv = rinv * (nwtn_const - (r2 * rinv) * rinv);
  }
  52. template <class T> inline T rsqrt_single_intrin(const T& r2) {
  53. if (r2 != 0) return 1.0 / sqrt<T>(r2);
  54. return 0;
  55. }
  // Returns the larger of `a` and `b`.  When the operands compare equal or are
  // unordered (NaN), the second operand `b` is returned.
  template <class T> inline T max_intrin(const T& a, const T& b) {
    return (a > b) ? a : b;
  }
  // Returns the smaller of `a` and `b`.
  //
  // The test is written as `b > a` so that ties and unordered operands (NaN)
  // yield the second operand `b`.  That is the convention of the SSE/AVX min
  // instructions used by the SIMD specializations of this function, and of
  // the scalar max_intrin, so the scalar fallback now agrees with the vector
  // code paths lane for lane.  For ordinary, distinct values the result is
  // unchanged.  Only operator> is required of T, as before.
  template <class T> inline T min_intrin(const T& a, const T& b) {
    if (b > a) return a;
    return b;
  }
  68. template <class T> inline T sin_intrin(const T& t) { return sin<T>(t); }
  69. template <class T> inline T cos_intrin(const T& t) { return cos<T>(t); }
  70. #ifdef __SSE3__
  71. template <> inline __m128 zero_intrin() { return _mm_setzero_ps(); }
  72. template <> inline __m128d zero_intrin() { return _mm_setzero_pd(); }
  73. template <> inline __m128 set_intrin(const float& a) { return _mm_set1_ps(a); }
  74. template <> inline __m128d set_intrin(const double& a) { return _mm_set1_pd(a); }
  75. template <> inline __m128 load_intrin(float const* a) { return _mm_load_ps(a); }
  76. template <> inline __m128d load_intrin(double const* a) { return _mm_load_pd(a); }
  77. template <> inline __m128 bcast_intrin(float const* a) { return _mm_set1_ps(a[0]); }
  78. template <> inline __m128d bcast_intrin(double const* a) { return _mm_load1_pd(a); }
  79. template <> inline void store_intrin(float* a, const __m128& b) { return _mm_store_ps(a, b); }
  80. template <> inline void store_intrin(double* a, const __m128d& b) { return _mm_store_pd(a, b); }
  81. template <> inline __m128 mul_intrin(const __m128& a, const __m128& b) { return _mm_mul_ps(a, b); }
  82. template <> inline __m128d mul_intrin(const __m128d& a, const __m128d& b) { return _mm_mul_pd(a, b); }
  83. template <> inline __m128 add_intrin(const __m128& a, const __m128& b) { return _mm_add_ps(a, b); }
  84. template <> inline __m128d add_intrin(const __m128d& a, const __m128d& b) { return _mm_add_pd(a, b); }
  85. template <> inline __m128 sub_intrin(const __m128& a, const __m128& b) { return _mm_sub_ps(a, b); }
  86. template <> inline __m128d sub_intrin(const __m128d& a, const __m128d& b) { return _mm_sub_pd(a, b); }
  87. template <> inline __m128 cmplt_intrin(const __m128& a, const __m128& b) { return _mm_cmplt_ps(a, b); }
  88. template <> inline __m128d cmplt_intrin(const __m128d& a, const __m128d& b) { return _mm_cmplt_pd(a, b); }
  89. template <> inline __m128 and_intrin(const __m128& a, const __m128& b) { return _mm_and_ps(a, b); }
  90. template <> inline __m128d and_intrin(const __m128d& a, const __m128d& b) { return _mm_and_pd(a, b); }
  91. template <> inline __m128 rsqrt_approx_intrin(const __m128& r2) {
  92. // Approx inverse square root which returns zero for r2=0
  93. return _mm_andnot_ps(_mm_cmpeq_ps(r2, zero_intrin<__m128>()), _mm_rsqrt_ps(r2));
  94. }
  95. template <> inline __m128d rsqrt_approx_intrin(const __m128d& r2) {
  96. return _mm_cvtps_pd(rsqrt_approx_intrin(_mm_cvtpd_ps(r2)));
  97. }
  98. template <> inline void rsqrt_newton_intrin(__m128& rinv, const __m128& r2, const float& nwtn_const) {
  99. // Newton iteration: rinv = 0.5 rinv_approx ( 3 - r2 rinv_approx^2 )
  100. // We do not compute the product with 0.5 and this needs to be adjusted later
  101. rinv = mul_intrin(rinv, sub_intrin(set_intrin<__m128>(nwtn_const), mul_intrin(r2, mul_intrin(rinv, rinv))));
  102. }
  103. template <> inline void rsqrt_newton_intrin(__m128d& rinv, const __m128d& r2, const double& nwtn_const) {
  104. // Newton iteration: rinv = 0.5 rinv_approx ( 3 - r2 rinv_approx^2 )
  105. // We do not compute the product with 0.5 and this needs to be adjusted later
  106. rinv = mul_intrin(rinv, sub_intrin(set_intrin<__m128d>(nwtn_const), mul_intrin(r2, mul_intrin(rinv, rinv))));
  107. }
  108. template <> inline __m128 rsqrt_single_intrin(const __m128& r2) {
  109. __m128 rinv = rsqrt_approx_intrin(r2);
  110. rsqrt_newton_intrin(rinv, r2, (float)3.0);
  111. return rinv;
  112. }
  113. template <> inline __m128d rsqrt_single_intrin(const __m128d& r2) {
  114. return _mm_cvtps_pd(rsqrt_single_intrin(_mm_cvtpd_ps(r2)));
  115. }
  116. template <> inline __m128 max_intrin(const __m128& a, const __m128& b) { return _mm_max_ps(a, b); }
  117. template <> inline __m128d max_intrin(const __m128d& a, const __m128d& b) { return _mm_max_pd(a, b); }
  118. template <> inline __m128 min_intrin(const __m128& a, const __m128& b) { return _mm_min_ps(a, b); }
  119. template <> inline __m128d min_intrin(const __m128d& a, const __m128d& b) { return _mm_min_pd(a, b); }
  120. #ifdef SCTL_HAVE_INTEL_SVML
  121. template <> inline __m128 sin_intrin(const __m128& t) { return _mm_sin_ps(t); }
  122. template <> inline __m128 cos_intrin(const __m128& t) { return _mm_cos_ps(t); }
  123. template <> inline __m128d sin_intrin(const __m128d& t) { return _mm_sin_pd(t); }
  124. template <> inline __m128d cos_intrin(const __m128d& t) { return _mm_cos_pd(t); }
  125. #else
  126. template <> inline __m128 sin_intrin(const __m128& t_) {
  127. union {
  128. float e[4];
  129. __m128 d;
  130. } t;
  131. store_intrin(t.e, t_);
  132. return _mm_set_ps(sin<float>(t.e[3]), sin<float>(t.e[2]), sin<float>(t.e[1]), sin<float>(t.e[0]));
  133. }
  134. template <> inline __m128 cos_intrin(const __m128& t_) {
  135. union {
  136. float e[4];
  137. __m128 d;
  138. } t;
  139. store_intrin(t.e, t_);
  140. return _mm_set_ps(cos<float>(t.e[3]), cos<float>(t.e[2]), cos<float>(t.e[1]), cos<float>(t.e[0]));
  141. }
  142. template <> inline __m128d sin_intrin(const __m128d& t_) {
  143. union {
  144. double e[2];
  145. __m128d d;
  146. } t;
  147. store_intrin(t.e, t_);
  148. return _mm_set_pd(sin<double>(t.e[1]), sin<double>(t.e[0]));
  149. }
  150. template <> inline __m128d cos_intrin(const __m128d& t_) {
  151. union {
  152. double e[2];
  153. __m128d d;
  154. } t;
  155. store_intrin(t.e, t_);
  156. return _mm_set_pd(cos<double>(t.e[1]), cos<double>(t.e[0]));
  157. }
  158. #endif
  159. #endif
  160. #ifdef __AVX__
  161. template <> inline __m256 zero_intrin() { return _mm256_setzero_ps(); }
  162. template <> inline __m256d zero_intrin() { return _mm256_setzero_pd(); }
  163. template <> inline __m256 set_intrin(const float& a) { return _mm256_set1_ps(a); }
  164. template <> inline __m256d set_intrin(const double& a) { return _mm256_set1_pd(a); }
  165. template <> inline __m256 load_intrin(float const* a) { return _mm256_load_ps(a); }
  166. template <> inline __m256d load_intrin(double const* a) { return _mm256_load_pd(a); }
  167. template <> inline __m256 bcast_intrin(float const* a) { return _mm256_broadcast_ss(a); }
  168. template <> inline __m256d bcast_intrin(double const* a) { return _mm256_broadcast_sd(a); }
  169. template <> inline void store_intrin(float* a, const __m256& b) { return _mm256_store_ps(a, b); }
  170. template <> inline void store_intrin(double* a, const __m256d& b) { return _mm256_store_pd(a, b); }
  171. template <> inline __m256 mul_intrin(const __m256& a, const __m256& b) { return _mm256_mul_ps(a, b); }
  172. template <> inline __m256d mul_intrin(const __m256d& a, const __m256d& b) { return _mm256_mul_pd(a, b); }
  173. template <> inline __m256 add_intrin(const __m256& a, const __m256& b) { return _mm256_add_ps(a, b); }
  174. template <> inline __m256d add_intrin(const __m256d& a, const __m256d& b) { return _mm256_add_pd(a, b); }
  175. template <> inline __m256 sub_intrin(const __m256& a, const __m256& b) { return _mm256_sub_ps(a, b); }
  176. template <> inline __m256d sub_intrin(const __m256d& a, const __m256d& b) { return _mm256_sub_pd(a, b); }
  177. template <> inline __m256 cmplt_intrin(const __m256& a, const __m256& b) { return _mm256_cmp_ps(a, b, _CMP_LT_OS); }
  178. template <> inline __m256d cmplt_intrin(const __m256d& a, const __m256d& b) { return _mm256_cmp_pd(a, b, _CMP_LT_OS); }
  179. template <> inline __m256 and_intrin(const __m256& a, const __m256& b) { return _mm256_and_ps(a, b); }
  180. template <> inline __m256d and_intrin(const __m256d& a, const __m256d& b) { return _mm256_and_pd(a, b); }
  181. template <> inline __m256 rsqrt_approx_intrin(const __m256& r2) {
  182. // Approx inverse square root which returns zero for r2=0
  183. return _mm256_andnot_ps(_mm256_cmp_ps(r2, zero_intrin<__m256>(), _CMP_EQ_OS), _mm256_rsqrt_ps(r2));
  184. }
  185. template <> inline __m256d rsqrt_approx_intrin(const __m256d& r2) {
  186. return _mm256_cvtps_pd(rsqrt_approx_intrin(_mm256_cvtpd_ps(r2)));
  187. }
  188. template <> inline void rsqrt_newton_intrin(__m256& rinv, const __m256& r2, const float& nwtn_const) {
  189. // Newton iteration: rinv = 0.5 rinv_approx ( 3 - r2 rinv_approx^2 )
  190. // We do not compute the product with 0.5 and this needs to be adjusted later
  191. rinv = mul_intrin(rinv, sub_intrin(set_intrin<__m256>(nwtn_const), mul_intrin(r2, mul_intrin(rinv, rinv))));
  192. }
  193. template <> inline void rsqrt_newton_intrin(__m256d& rinv, const __m256d& r2, const double& nwtn_const) {
  194. // Newton iteration: rinv = 0.5 rinv_approx ( 3 - r2 rinv_approx^2 )
  195. // We do not compute the product with 0.5 and this needs to be adjusted later
  196. rinv = mul_intrin(rinv, sub_intrin(set_intrin<__m256d>(nwtn_const), mul_intrin(r2, mul_intrin(rinv, rinv))));
  197. }
  198. template <> inline __m256 rsqrt_single_intrin(const __m256& r2) {
  199. __m256 rinv = rsqrt_approx_intrin(r2);
  200. rsqrt_newton_intrin(rinv, r2, (float)3.0);
  201. return rinv;
  202. }
  203. template <> inline __m256d rsqrt_single_intrin(const __m256d& r2) {
  204. return _mm256_cvtps_pd(rsqrt_single_intrin(_mm256_cvtpd_ps(r2)));
  205. }
  206. template <> inline __m256 max_intrin(const __m256& a, const __m256& b) { return _mm256_max_ps(a, b); }
  207. template <> inline __m256d max_intrin(const __m256d& a, const __m256d& b) { return _mm256_max_pd(a, b); }
  208. template <> inline __m256 min_intrin(const __m256& a, const __m256& b) { return _mm256_min_ps(a, b); }
  209. template <> inline __m256d min_intrin(const __m256d& a, const __m256d& b) { return _mm256_min_pd(a, b); }
  210. #ifdef SCTL_HAVE_INTEL_SVML
  211. template <> inline __m256 sin_intrin(const __m256& t) { return _mm256_sin_ps(t); }
  212. template <> inline __m256 cos_intrin(const __m256& t) { return _mm256_cos_ps(t); }
  213. template <> inline __m256d sin_intrin(const __m256d& t) { return _mm256_sin_pd(t); }
  214. template <> inline __m256d cos_intrin(const __m256d& t) { return _mm256_cos_pd(t); }
  215. #else
  216. template <> inline __m256 sin_intrin(const __m256& t_) {
  217. union {
  218. float e[8];
  219. __m256 d;
  220. } t;
  221. store_intrin(t.e, t_); // t.d=t_;
  222. return _mm256_set_ps(sin<float>(t.e[7]), sin<float>(t.e[6]), sin<float>(t.e[5]), sin<float>(t.e[4]), sin<float>(t.e[3]), sin<float>(t.e[2]), sin<float>(t.e[1]), sin<float>(t.e[0]));
  223. }
  224. template <> inline __m256 cos_intrin(const __m256& t_) {
  225. union {
  226. float e[8];
  227. __m256 d;
  228. } t;
  229. store_intrin(t.e, t_); // t.d=t_;
  230. return _mm256_set_ps(cos<float>(t.e[7]), cos<float>(t.e[6]), cos<float>(t.e[5]), cos<float>(t.e[4]), cos<float>(t.e[3]), cos<float>(t.e[2]), cos<float>(t.e[1]), cos<float>(t.e[0]));
  231. }
  232. template <> inline __m256d sin_intrin(const __m256d& t_) {
  233. union {
  234. double e[4];
  235. __m256d d;
  236. } t;
  237. store_intrin(t.e, t_); // t.d=t_;
  238. return _mm256_set_pd(sin<double>(t.e[3]), sin<double>(t.e[2]), sin<double>(t.e[1]), sin<double>(t.e[0]));
  239. }
  240. template <> inline __m256d cos_intrin(const __m256d& t_) {
  241. union {
  242. double e[4];
  243. __m256d d;
  244. } t;
  245. store_intrin(t.e, t_); // t.d=t_;
  246. return _mm256_set_pd(cos<double>(t.e[3]), cos<double>(t.e[2]), cos<double>(t.e[1]), cos<double>(t.e[0]));
  247. }
  248. #endif
  249. #endif
  // Reciprocal square root with zero Newton refinements: simply the estimate
  // from rsqrt_approx_intrin.  The Real parameter is not used by this variant;
  // it keeps the signature parallel to rsqrt_intrin1..3.
  template <class VEC, class Real> inline VEC rsqrt_intrin0(VEC r2) {
    return rsqrt_approx_intrin(r2);
  }
  // Reciprocal square root with one Newton refinement.
  // The step omits the usual factor 0.5, so the value returned is the refined
  // estimate scaled by 2; the caller is expected to fold that factor in.
  template <class VEC, class Real> inline VEC rsqrt_intrin1(VEC r2) {
    const Real newton_c1 = 3;
    VEC y = rsqrt_approx_intrin(r2);
    rsqrt_newton_intrin(y, r2, newton_c1);
    return y;
  }
  // Reciprocal square root with two Newton refinements.
  // Each step omits the factor 0.5.  After the first step the iterate is scaled
  // by 2, which turns the constant of the second step into 3 * 2^2 = 12; the
  // value returned is the refined estimate scaled by 16.
  template <class VEC, class Real> inline VEC rsqrt_intrin2(VEC r2) {
    const Real newton_c1 = 3;
    const Real newton_c2 = 12;
    VEC y = rsqrt_approx_intrin(r2);
    rsqrt_newton_intrin(y, r2, newton_c1);
    rsqrt_newton_intrin(y, r2, newton_c2);
    return y;
  }
  // Reciprocal square root with three Newton refinements.
  // Each step omits the factor 0.5, so the iterate carries a growing scale:
  // 2 after the first step, 16 after the second.  The constants compensate for
  // it (3, 3 * 2^2 = 12, 3 * 16^2 = 768); the value returned is the refined
  // estimate scaled by 8192.
  template <class VEC, class Real> inline VEC rsqrt_intrin3(VEC r2) {
    const Real newton_c1 = 3;
    const Real newton_c2 = 12;
    const Real newton_c3 = 768;
    VEC y;
    y = rsqrt_approx_intrin(r2);
    rsqrt_newton_intrin(y, r2, newton_c1);
    rsqrt_newton_intrin(y, r2, newton_c2);
    rsqrt_newton_intrin(y, r2, newton_c3);
    return y;
  }
  281. }
  282. #endif //_SCTL_INTRIN_WRAPPER_HPP_