intrin_wrapper.hpp 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545
  1. #ifndef _SCTL_INTRIN_WRAPPER_HPP_
  2. #define _SCTL_INTRIN_WRAPPER_HPP_
  3. #include SCTL_INCLUDE(math_utils.hpp)
  4. #include SCTL_INCLUDE(common.hpp)
  5. #include <cstdint>
  6. #ifdef __SSE__
  7. #include <xmmintrin.h>
  8. #endif
  9. #ifdef __SSE2__
  10. #include <emmintrin.h>
  11. #endif
  12. #ifdef __SSE3__
  13. #include <pmmintrin.h>
  14. #endif
  15. #ifdef __AVX__
  16. #include <immintrin.h>
  17. #endif
  18. #if defined(__MIC__)
  19. #include <immintrin.h>
  20. #endif
  21. // TODO: Check alignment when SCTL_MEMDEBUG is defined
  22. // TODO: Replace pointers with iterators
  23. namespace SCTL_NAMESPACE {
  // Generic fallbacks. For a non-SIMD type T a single value acts as a
  // one-lane "vector": loads/broadcasts read a[0], stores write a[0], and the
  // arithmetic helpers forward to T's own operators. SIMD types get explicit
  // specializations later in this header, guarded by the ISA feature macros.

  /// Returns zero as a T.
  template <class T> inline T zero_intrin() {
    return T(0);
  }

  /// Returns `a` as a T.
  template <class T, class Real> inline T set_intrin(const Real& a) {
    return a;
  }

  /// Reads one value from `*a`.
  template <class T, class Real> inline T load_intrin(Real const* a) {
    return *a;
  }

  /// Broadcast-load: for a scalar T this is the same as load_intrin.
  template <class T, class Real> inline T bcast_intrin(Real const* a) {
    return *a;
  }

  /// Writes `b` to `*a`.
  template <class T, class Real> inline void store_intrin(Real* a, const T& b) {
    *a = b;
  }

  /// Element-wise a * b.
  template <class T> inline T mul_intrin(const T& a, const T& b) {
    return a * b;
  }

  /// Element-wise a + b.
  template <class T> inline T add_intrin(const T& a, const T& b) {
    return a + b;
  }

  /// Element-wise a - b.
  template <class T> inline T sub_intrin(const T& a, const T& b) {
    return a - b;
  }
  /**
   * Generic "less than" mask.
   *
   * Returns a T whose bytes are all 0xFF when a < b and all zero otherwise.
   * This mirrors the all-ones / all-zeros lane masks of the SIMD compare
   * specializations, so the result can be fed to and_intrin.
   */
  template <class T> inline T cmplt_intrin(const T& a, const T& b) {
    T mask = 0;
    if (!(a < b)) return mask;
    uint8_t* bytes = reinterpret_cast<uint8_t*>(&mask);
    for (unsigned k = 0; k < sizeof(T); ++k) bytes[k] = static_cast<uint8_t>(0xFF);
    return mask;
  }
  /**
   * Generic bitwise AND of the object representations of a and b.
   *
   * Works for any trivially copyable T (including floating point), which lets
   * a mask from cmplt_intrin select or zero a value.
   */
  template <class T> inline T and_intrin(const T& a, const T& b) {
    T out = 0;
    const uint8_t* lhs = reinterpret_cast<const uint8_t*>(&a);
    const uint8_t* rhs = reinterpret_cast<const uint8_t*>(&b);
    uint8_t* dst = reinterpret_cast<uint8_t*>(&out);
    for (unsigned k = 0; k < sizeof(T); ++k) dst[k] = static_cast<uint8_t>(lhs[k] & rhs[k]);
    return out;
  }
  47. template <class T> inline T rsqrt_approx_intrin(const T& r2) {
  48. if (r2 != 0) return 1.0 / sqrt<T>(r2);
  49. return 0;
  50. }
  /**
   * Generic Newton step for 1/sqrt(r2), updating `rinv` in place.
   *
   * Computes rinv * (nwtn_const - r2 * rinv * rinv). The usual 0.5 factor of
   * the Newton update is deliberately omitted; the caller picks nwtn_const
   * (3 for an unscaled input) and accounts for the extra scale afterwards.
   * `auto` keeps the bracketed term in the promoted type of the expression,
   * as in the single-expression form.
   */
  template <class T, class Real> inline void rsqrt_newton_intrin(T& rinv, const T& r2, const Real& nwtn_const) {
    const auto correction = nwtn_const - r2 * rinv * rinv;
    rinv = rinv * correction;
  }
  52. template <class T> inline T rsqrt_single_intrin(const T& r2) {
  53. if (r2 != 0) return 1.0 / sqrt<T>(r2);
  54. return 0;
  55. }
  /**
   * Generic maximum: returns a when a > b, otherwise b.
   *
   * So for NaN operands or equal values (e.g. +0.0 vs -0.0) the second operand
   * wins, which is the same operand selection as _mm_max_ps/_mm_max_pd.
   */
  template <class T> inline T max_intrin(const T& a, const T& b) {
    return (a > b) ? a : b;
  }
  /**
   * Generic minimum: returns a when a < b (written as b > a), otherwise b.
   *
   * This uses the same operand selection as the SIMD specializations
   * (_mm_min_ps, _mm_min_pd, _mm256_min_ps, _mm256_min_pd): if either operand
   * is NaN, or the two compare equal (e.g. -0.0 and +0.0), the second operand
   * b is returned. It also mirrors max_intrin, which returns a only when the
   * strict comparison holds. The previous form returned a in those cases, so
   * scalar and SIMD code paths disagreed.
   * The comparison is written with operator> so T needs no more operators
   * than before.
   */
  template <class T> inline T min_intrin(const T& a, const T& b) {
    return (b > a) ? a : b;
  }
  68. template <class T> inline T sin_intrin(const T& t) { return sin<T>(t); }
  69. template <class T> inline T cos_intrin(const T& t) { return cos<T>(t); }
  70. #ifdef __SSE3__
  71. template <> inline __m128 zero_intrin() { return _mm_setzero_ps(); }
  72. template <> inline __m128d zero_intrin() { return _mm_setzero_pd(); }
  73. template <> inline __m128 set_intrin(const float& a) { return _mm_set_ps1(a); }
  74. template <> inline __m128d set_intrin(const double& a) { return _mm_set_pd1(a); }
  75. template <> inline __m128 load_intrin(float const* a) { return _mm_load_ps(a); }
  76. template <> inline __m128d load_intrin(double const* a) { return _mm_load_pd(a); }
  77. template <> inline __m128 bcast_intrin(float const* a) { return _mm_set_ps1(a[0]); }
  78. template <> inline __m128d bcast_intrin(double const* a) { return _mm_load_pd1(a); }
  79. template <> inline void store_intrin(float* a, const __m128& b) { return _mm_store_ps(a, b); }
  80. template <> inline void store_intrin(double* a, const __m128d& b) { return _mm_store_pd(a, b); }
  81. template <> inline __m128 mul_intrin(const __m128& a, const __m128& b) { return _mm_mul_ps(a, b); }
  82. template <> inline __m128d mul_intrin(const __m128d& a, const __m128d& b) { return _mm_mul_pd(a, b); }
  83. template <> inline __m128 add_intrin(const __m128& a, const __m128& b) { return _mm_add_ps(a, b); }
  84. template <> inline __m128d add_intrin(const __m128d& a, const __m128d& b) { return _mm_add_pd(a, b); }
  85. template <> inline __m128 sub_intrin(const __m128& a, const __m128& b) { return _mm_sub_ps(a, b); }
  86. template <> inline __m128d sub_intrin(const __m128d& a, const __m128d& b) { return _mm_sub_pd(a, b); }
  87. template <> inline __m128 cmplt_intrin(const __m128& a, const __m128& b) { return _mm_cmplt_ps(a, b); }
  88. template <> inline __m128d cmplt_intrin(const __m128d& a, const __m128d& b) { return _mm_cmplt_pd(a, b); }
  89. template <> inline __m128 and_intrin(const __m128& a, const __m128& b) { return _mm_and_ps(a, b); }
  90. template <> inline __m128d and_intrin(const __m128d& a, const __m128d& b) { return _mm_and_pd(a, b); }
  91. template <> inline __m128 rsqrt_approx_intrin(const __m128& r2) {
  92. #define VEC_INTRIN __m128
  93. #define RSQRT_INTRIN(a) _mm_rsqrt_ps(a)
  94. #define CMPEQ_INTRIN(a, b) _mm_cmpeq_ps(a, b)
  95. #define ANDNOT_INTRIN(a, b) _mm_andnot_ps(a, b)
  96. // Approx inverse square root which returns zero for r2=0
  97. return ANDNOT_INTRIN(CMPEQ_INTRIN(r2, zero_intrin<VEC_INTRIN>()), RSQRT_INTRIN(r2));
  98. #undef VEC_INTRIN
  99. #undef RSQRT_INTRIN
  100. #undef CMPEQ_INTRIN
  101. #undef ANDNOT_INTRIN
  102. }
  103. template <> inline __m128d rsqrt_approx_intrin(const __m128d& r2) {
  104. #define PD2PS(a) _mm_cvtpd_ps(a)
  105. #define PS2PD(a) _mm_cvtps_pd(a)
  106. return PS2PD(rsqrt_approx_intrin(PD2PS(r2)));
  107. #undef PD2PS
  108. #undef PS2PD
  109. }
  110. template <> inline void rsqrt_newton_intrin(__m128& rinv, const __m128& r2, const float& nwtn_const) {
  111. #define VEC_INTRIN __m128
  112. // Newton iteration: rinv = 0.5 rinv_approx ( 3 - r2 rinv_approx^2 )
  113. // We do not compute the product with 0.5 and this needs to be adjusted later
  114. rinv = mul_intrin(rinv, sub_intrin(set_intrin<VEC_INTRIN>(nwtn_const), mul_intrin(r2, mul_intrin(rinv, rinv))));
  115. #undef VEC_INTRIN
  116. }
  117. template <> inline void rsqrt_newton_intrin(__m128d& rinv, const __m128d& r2, const double& nwtn_const) {
  118. #define VEC_INTRIN __m128d
  119. // Newton iteration: rinv = 0.5 rinv_approx ( 3 - r2 rinv_approx^2 )
  120. // We do not compute the product with 0.5 and this needs to be adjusted later
  121. rinv = mul_intrin(rinv, sub_intrin(set_intrin<VEC_INTRIN>(nwtn_const), mul_intrin(r2, mul_intrin(rinv, rinv))));
  122. #undef VEC_INTRIN
  123. }
  124. template <> inline __m128 rsqrt_single_intrin(const __m128& r2) {
  125. #define VEC_INTRIN __m128
  126. VEC_INTRIN rinv = rsqrt_approx_intrin(r2);
  127. rsqrt_newton_intrin(rinv, r2, (float)3.0);
  128. return rinv;
  129. #undef VEC_INTRIN
  130. }
  131. template <> inline __m128d rsqrt_single_intrin(const __m128d& r2) {
  132. #define PD2PS(a) _mm_cvtpd_ps(a)
  133. #define PS2PD(a) _mm_cvtps_pd(a)
  134. return PS2PD(rsqrt_single_intrin(PD2PS(r2)));
  135. #undef PD2PS
  136. #undef PS2PD
  137. }
  138. template <> inline __m128 max_intrin(const __m128& a, const __m128& b) { return _mm_max_ps(a, b); }
  139. template <> inline __m128d max_intrin(const __m128d& a, const __m128d& b) { return _mm_max_pd(a, b); }
  140. template <> inline __m128 min_intrin(const __m128& a, const __m128& b) { return _mm_min_ps(a, b); }
  141. template <> inline __m128d min_intrin(const __m128d& a, const __m128d& b) { return _mm_min_pd(a, b); }
  142. #ifdef SCTL_HAVE_INTEL_SVML
  143. template <> inline __m128 sin_intrin(const __m128& t) { return _mm_sin_ps(t); }
  144. template <> inline __m128 cos_intrin(const __m128& t) { return _mm_cos_ps(t); }
  145. template <> inline __m128d sin_intrin(const __m128d& t) { return _mm_sin_pd(t); }
  146. template <> inline __m128d cos_intrin(const __m128d& t) { return _mm_cos_pd(t); }
  147. #else
  148. template <> inline __m128 sin_intrin(const __m128& t_) {
  149. union {
  150. float e[4];
  151. __m128 d;
  152. } t;
  153. store_intrin(t.e, t_);
  154. return _mm_set_ps(sin<float>(t.e[3]), sin<float>(t.e[2]), sin<float>(t.e[1]), sin<float>(t.e[0]));
  155. }
  156. template <> inline __m128 cos_intrin(const __m128& t_) {
  157. union {
  158. float e[4];
  159. __m128 d;
  160. } t;
  161. store_intrin(t.e, t_);
  162. return _mm_set_ps(cos<float>(t.e[3]), cos<float>(t.e[2]), cos<float>(t.e[1]), cos<float>(t.e[0]));
  163. }
  164. template <> inline __m128d sin_intrin(const __m128d& t_) {
  165. union {
  166. double e[2];
  167. __m128d d;
  168. } t;
  169. store_intrin(t.e, t_);
  170. return _mm_set_pd(sin<double>(t.e[1]), sin<double>(t.e[0]));
  171. }
  172. template <> inline __m128d cos_intrin(const __m128d& t_) {
  173. union {
  174. double e[2];
  175. __m128d d;
  176. } t;
  177. store_intrin(t.e, t_);
  178. return _mm_set_pd(cos<double>(t.e[1]), cos<double>(t.e[0]));
  179. }
  180. #endif
  181. #endif
  182. #ifdef __AVX__
  183. template <> inline __m256 zero_intrin() { return _mm256_setzero_ps(); }
  184. template <> inline __m256d zero_intrin() { return _mm256_setzero_pd(); }
  185. template <> inline __m256 set_intrin(const float& a) { return _mm256_set_ps(a, a, a, a, a, a, a, a); }
  186. template <> inline __m256d set_intrin(const double& a) { return _mm256_set_pd(a, a, a, a); }
  187. template <> inline __m256 load_intrin(float const* a) { return _mm256_load_ps(a); }
  188. template <> inline __m256d load_intrin(double const* a) { return _mm256_load_pd(a); }
  189. template <> inline __m256 bcast_intrin(float const* a) { return _mm256_broadcast_ss(a); }
  190. template <> inline __m256d bcast_intrin(double const* a) { return _mm256_broadcast_sd(a); }
  191. template <> inline void store_intrin(float* a, const __m256& b) { return _mm256_store_ps(a, b); }
  192. template <> inline void store_intrin(double* a, const __m256d& b) { return _mm256_store_pd(a, b); }
  193. template <> inline __m256 mul_intrin(const __m256& a, const __m256& b) { return _mm256_mul_ps(a, b); }
  194. template <> inline __m256d mul_intrin(const __m256d& a, const __m256d& b) { return _mm256_mul_pd(a, b); }
  195. template <> inline __m256 add_intrin(const __m256& a, const __m256& b) { return _mm256_add_ps(a, b); }
  196. template <> inline __m256d add_intrin(const __m256d& a, const __m256d& b) { return _mm256_add_pd(a, b); }
  197. template <> inline __m256 sub_intrin(const __m256& a, const __m256& b) { return _mm256_sub_ps(a, b); }
  198. template <> inline __m256d sub_intrin(const __m256d& a, const __m256d& b) { return _mm256_sub_pd(a, b); }
  199. template <> inline __m256 cmplt_intrin(const __m256& a, const __m256& b) { return _mm256_cmp_ps(a, b, _CMP_LT_OS); }
  200. template <> inline __m256d cmplt_intrin(const __m256d& a, const __m256d& b) { return _mm256_cmp_pd(a, b, _CMP_LT_OS); }
  201. template <> inline __m256 and_intrin(const __m256& a, const __m256& b) { return _mm256_and_ps(a, b); }
  202. template <> inline __m256d and_intrin(const __m256d& a, const __m256d& b) { return _mm256_and_pd(a, b); }
  203. template <> inline __m256 rsqrt_approx_intrin(const __m256& r2) {
  204. #define VEC_INTRIN __m256
  205. #define RSQRT_INTRIN(a) _mm256_rsqrt_ps(a)
  206. #define CMPEQ_INTRIN(a, b) _mm256_cmp_ps(a, b, _CMP_EQ_OS)
  207. #define ANDNOT_INTRIN(a, b) _mm256_andnot_ps(a, b)
  208. // Approx inverse square root which returns zero for r2=0
  209. return ANDNOT_INTRIN(CMPEQ_INTRIN(r2, zero_intrin<VEC_INTRIN>()), RSQRT_INTRIN(r2));
  210. #undef VEC_INTRIN
  211. #undef RSQRT_INTRIN
  212. #undef CMPEQ_INTRIN
  213. #undef ANDNOT_INTRIN
  214. }
  215. template <> inline __m256d rsqrt_approx_intrin(const __m256d& r2) {
  216. #define PD2PS(a) _mm256_cvtpd_ps(a)
  217. #define PS2PD(a) _mm256_cvtps_pd(a)
  218. return PS2PD(rsqrt_approx_intrin(PD2PS(r2)));
  219. #undef PD2PS
  220. #undef PS2PD
  221. }
  222. template <> inline void rsqrt_newton_intrin(__m256& rinv, const __m256& r2, const float& nwtn_const) {
  223. #define VEC_INTRIN __m256
  224. // Newton iteration: rinv = 0.5 rinv_approx ( 3 - r2 rinv_approx^2 )
  225. // We do not compute the product with 0.5 and this needs to be adjusted later
  226. rinv = mul_intrin(rinv, sub_intrin(set_intrin<VEC_INTRIN>(nwtn_const), mul_intrin(r2, mul_intrin(rinv, rinv))));
  227. #undef VEC_INTRIN
  228. }
  229. template <> inline void rsqrt_newton_intrin(__m256d& rinv, const __m256d& r2, const double& nwtn_const) {
  230. #define VEC_INTRIN __m256d
  231. // Newton iteration: rinv = 0.5 rinv_approx ( 3 - r2 rinv_approx^2 )
  232. // We do not compute the product with 0.5 and this needs to be adjusted later
  233. rinv = mul_intrin(rinv, sub_intrin(set_intrin<VEC_INTRIN>(nwtn_const), mul_intrin(r2, mul_intrin(rinv, rinv))));
  234. #undef VEC_INTRIN
  235. }
  236. template <> inline __m256 rsqrt_single_intrin(const __m256& r2) {
  237. #define VEC_INTRIN __m256
  238. VEC_INTRIN rinv = rsqrt_approx_intrin(r2);
  239. rsqrt_newton_intrin(rinv, r2, (float)3.0);
  240. return rinv;
  241. #undef VEC_INTRIN
  242. }
  243. template <> inline __m256d rsqrt_single_intrin(const __m256d& r2) {
  244. #define PD2PS(a) _mm256_cvtpd_ps(a)
  245. #define PS2PD(a) _mm256_cvtps_pd(a)
  246. return PS2PD(rsqrt_single_intrin(PD2PS(r2)));
  247. #undef PD2PS
  248. #undef PS2PD
  249. }
  250. template <> inline __m256 max_intrin(const __m256& a, const __m256& b) { return _mm256_max_ps(a, b); }
  251. template <> inline __m256d max_intrin(const __m256d& a, const __m256d& b) { return _mm256_max_pd(a, b); }
  252. template <> inline __m256 min_intrin(const __m256& a, const __m256& b) { return _mm256_min_ps(a, b); }
  253. template <> inline __m256d min_intrin(const __m256d& a, const __m256d& b) { return _mm256_min_pd(a, b); }
  254. #ifdef SCTL_HAVE_INTEL_SVML
  255. template <> inline __m256 sin_intrin(const __m256& t) { return _mm256_sin_ps(t); }
  256. template <> inline __m256 cos_intrin(const __m256& t) { return _mm256_cos_ps(t); }
  257. template <> inline __m256d sin_intrin(const __m256d& t) { return _mm256_sin_pd(t); }
  258. template <> inline __m256d cos_intrin(const __m256d& t) { return _mm256_cos_pd(t); }
  259. #else
  260. template <> inline __m256 sin_intrin(const __m256& t_) {
  261. union {
  262. float e[8];
  263. __m256 d;
  264. } t;
  265. store_intrin(t.e, t_); // t.d=t_;
  266. return _mm256_set_ps(sin<float>(t.e[7]), sin<float>(t.e[6]), sin<float>(t.e[5]), sin<float>(t.e[4]), sin<float>(t.e[3]), sin<float>(t.e[2]), sin<float>(t.e[1]), sin<float>(t.e[0]));
  267. }
  268. template <> inline __m256 cos_intrin(const __m256& t_) {
  269. union {
  270. float e[8];
  271. __m256 d;
  272. } t;
  273. store_intrin(t.e, t_); // t.d=t_;
  274. return _mm256_set_ps(cos<float>(t.e[7]), cos<float>(t.e[6]), cos<float>(t.e[5]), cos<float>(t.e[4]), cos<float>(t.e[3]), cos<float>(t.e[2]), cos<float>(t.e[1]), cos<float>(t.e[0]));
  275. }
  276. template <> inline __m256d sin_intrin(const __m256d& t_) {
  277. union {
  278. double e[4];
  279. __m256d d;
  280. } t;
  281. store_intrin(t.e, t_); // t.d=t_;
  282. return _mm256_set_pd(sin<double>(t.e[3]), sin<double>(t.e[2]), sin<double>(t.e[1]), sin<double>(t.e[0]));
  283. }
  284. template <> inline __m256d cos_intrin(const __m256d& t_) {
  285. union {
  286. double e[4];
  287. __m256d d;
  288. } t;
  289. store_intrin(t.e, t_); // t.d=t_;
  290. return _mm256_set_pd(cos<double>(t.e[3]), cos<double>(t.e[2]), cos<double>(t.e[1]), cos<double>(t.e[0]));
  291. }
  292. #endif
  293. #endif
  /**
   * Reciprocal square root with no Newton refinement.
   *
   * Returns rsqrt_approx_intrin(r2) unchanged (scale factor 1): the hardware
   * estimate for SIMD types, 1/sqrt(r2) for scalars, and 0 where r2 == 0.
   * `Real` is unused here; it is kept so every rsqrt_intrinN variant has the
   * same template signature.
   */
  template <class VEC, class Real> inline VEC rsqrt_intrin0(VEC r2) {
    return rsqrt_approx_intrin(r2);
  }
  /**
   * Reciprocal square root: estimate plus one unscaled Newton step.
   *
   * rsqrt_newton_intrin leaves out the 0.5 of the Newton update. A step on a
   * value carrying scale s must use the constant 3*s^2, and its result carries
   * scale 2*s^3. Starting from s = 1, the single step here uses the constant 3,
   * so the returned value is about 2 / sqrt(r2); the factor has to be removed
   * by the caller.
   */
  template <class VEC, class Real> inline VEC rsqrt_intrin1(VEC r2) {
    Real s = 1;                       // scale of the initial estimate
    const Real c_step1 = 3 * s * s;   // = 3

    VEC rinv = rsqrt_approx_intrin(r2);
    rsqrt_newton_intrin(rinv, r2, c_step1);
    return rinv;
  }
  /**
   * Reciprocal square root: estimate plus two unscaled Newton steps.
   *
   * rsqrt_newton_intrin leaves out the 0.5 of the Newton update. A step on a
   * value carrying scale s must use the constant 3*s^2, and its result carries
   * scale 2*s^3. Here the steps use the constants 3 (s = 1) and 12 (s = 2), so
   * the returned value is about 16 / sqrt(r2); the factor has to be removed by
   * the caller.
   */
  template <class VEC, class Real> inline VEC rsqrt_intrin2(VEC r2) {
    Real s = 1;                       // scale of the initial estimate
    const Real c_step1 = 3 * s * s;   // = 3
    s = 2 * s * s * s;                // scale after step 1: 2
    const Real c_step2 = 3 * s * s;   // = 12

    VEC rinv = rsqrt_approx_intrin(r2);
    rsqrt_newton_intrin(rinv, r2, c_step1);
    rsqrt_newton_intrin(rinv, r2, c_step2);
    return rinv;
  }
  /**
   * Reciprocal square root: estimate plus three unscaled Newton steps.
   *
   * rsqrt_newton_intrin leaves out the 0.5 of the Newton update. A step on a
   * value carrying scale s must use the constant 3*s^2, and its result carries
   * scale 2*s^3. Here the steps use the constants 3 (s = 1), 12 (s = 2) and
   * 768 (s = 16), so the returned value is about 8192 / sqrt(r2); the factor
   * has to be removed by the caller.
   */
  template <class VEC, class Real> inline VEC rsqrt_intrin3(VEC r2) {
    Real s = 1;                       // scale of the initial estimate
    const Real c_step1 = 3 * s * s;   // = 3
    s = 2 * s * s * s;                // scale after step 1: 2
    const Real c_step2 = 3 * s * s;   // = 12
    s = 2 * s * s * s;                // scale after step 2: 16
    const Real c_step3 = 3 * s * s;   // = 768

    VEC rinv = rsqrt_approx_intrin(r2);
    rsqrt_newton_intrin(rinv, r2, c_step1);
    rsqrt_newton_intrin(rinv, r2, c_step2);
    rsqrt_newton_intrin(rinv, r2, c_step3);
    return rinv;
  }
  420. }
  421. #endif //_SCTL_INTRIN_WRAPPER_HPP_