vec.hpp 33 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071
  1. #ifndef _SCTL_VEC_WRAPPER_HPP_
  2. #define _SCTL_VEC_WRAPPER_HPP_
  #include SCTL_INCLUDE(math_utils.hpp)
  #include SCTL_INCLUDE(common.hpp)
  #include <cstdint>
  #include <cstring>
  #include <ostream>
  #include <type_traits>
  7. #ifdef __SSE__
  8. #include <xmmintrin.h>
  9. #endif
  10. #ifdef __SSE2__
  11. #include <emmintrin.h>
  12. #endif
  13. #ifdef __SSE3__
  14. #include <pmmintrin.h>
  15. #endif
  16. #ifdef __SSE4_2__
  17. #include <smmintrin.h>
  18. #endif
  19. #ifdef __AVX__
  20. #include <immintrin.h>
  21. #endif
  22. #if defined(__MIC__)
  23. #include <immintrin.h>
  24. #endif
  25. // TODO: Implement AVX versions of floats, int32_t, int64_t
  26. // TODO: Check alignment when SCTL_MEMDEBUG is defined
  27. // TODO: Replace pointers with iterators
  28. namespace SCTL_NAMESPACE {
  // Coarse category of a scalar type; used by TypeTraits and GetType to map
  // between scalar types of the same byte width.
  enum class DataType { Integer, Real, Bool };
  34. template <class ValueType> class TypeTraits {
  35. public:
  36. static constexpr DataType Type = DataType::Bool;
  37. static constexpr Integer Size = sizeof(ValueType);
  38. static constexpr Integer SigBits = 1;
  39. };
  40. template <> class TypeTraits<int32_t> {
  41. public:
  42. static constexpr DataType Type = DataType::Integer;
  43. static constexpr Integer Size = sizeof(int32_t);
  44. static constexpr Integer SigBits = Size * 8;
  45. };
  46. template <> class TypeTraits<int64_t> {
  47. public:
  48. static constexpr DataType Type = DataType::Integer;
  49. static constexpr Integer Size = sizeof(int64_t);
  50. static constexpr Integer SigBits = Size * 8;
  51. };
  52. template <> class TypeTraits<float> {
  53. public:
  54. static constexpr DataType Type = DataType::Real;
  55. static constexpr Integer Size = sizeof(float);
  56. static constexpr Integer SigBits = 23;
  57. };
  58. template <> class TypeTraits<double> {
  59. public:
  60. static constexpr DataType Type = DataType::Real;
  61. static constexpr Integer Size = sizeof(double);
  62. static constexpr Integer SigBits = 52;
  63. };
  64. template <DataType type, Integer size> class GetType {
  65. public:
  66. typedef bool ValueType;
  67. };
  68. template <> class GetType<DataType::Integer,4> {
  69. public:
  70. typedef int32_t ValueType;
  71. };
  72. template <> class GetType<DataType::Integer,8> {
  73. public:
  74. typedef int64_t ValueType;
  75. };
  76. template <> class GetType<DataType::Real,4> {
  77. public:
  78. typedef float ValueType;
  79. };
  80. template <> class GetType<DataType::Real,8> {
  81. public:
  82. typedef double ValueType;
  83. };
  84. template <class ValueType, Integer N> class alignas(sizeof(ValueType) * N) Vec {
  85. public:
  86. typedef typename GetType<DataType::Integer,TypeTraits<ValueType>::Size>::ValueType IntegerType;
  87. typedef typename GetType<DataType::Real,TypeTraits<ValueType>::Size>::ValueType RealType;
  88. typedef Vec<IntegerType,N> IntegerVec;
  89. typedef Vec<RealType,N> RealVec;
  90. typedef ValueType ScalarType;
  91. static constexpr Integer Size() {
  92. return N;
  93. }
  94. static Vec Zero() {
  95. Vec r;
  96. for (Integer i = 0; i < N; i++) r.v[i] = 0;
  97. return r;
  98. }
  99. static Vec Load1(ValueType const* p) {
  100. Vec r;
  101. for (Integer i = 0; i < N; i++) r.v[i] = p[0];
  102. return r;
  103. }
  104. static Vec Load(ValueType const* p) {
  105. Vec r;
  106. for (Integer i = 0; i < N; i++) r.v[i] = p[i];
  107. return r;
  108. }
  109. static Vec LoadAligned(ValueType const* p) {
  110. Vec r;
  111. for (Integer i = 0; i < N; i++) r.v[i] = p[i];
  112. return r;
  113. }
  114. Vec() = default;
  115. Vec(const ValueType& a) {
  116. for (Integer i = 0; i < N; i++) v[i] = a;
  117. }
  118. void Store(ValueType* p) const {
  119. for (Integer i = 0; i < N; i++) p[i] = v[i];
  120. }
  121. void StoreAligned(ValueType* p) const {
  122. for (Integer i = 0; i < N; i++) p[i] = v[i];
  123. }
  124. // Bitwise NOT
  125. Vec operator~() const {
  126. Vec r;
  127. char* vo = (char*)r.v;
  128. const char* vi = (const char*)this->v;
  129. for (Integer i = 0; i < (Integer)(N*sizeof(ValueType)); i++) vo[i] = ~vi[i];
  130. return r;
  131. }
  132. // Unary plus and minus
  133. Vec operator+() const {
  134. return *this;
  135. }
  136. Vec operator-() const {
  137. Vec r;
  138. for (Integer i = 0; i < N; i++) r.v[i] = -v[i];
  139. return r;
  140. }
  141. // C-style cast
  142. //template <class RetValueType> explicit operator Vec<RetValueType,N>() const {
  143. // Vec<RetValueType,N> r;
  144. // for (Integer i = 0; i < N; i++) r.v[i] = (RetValueType)v[i];
  145. // return r;
  146. //}
  147. // Arithmetic operators
  148. friend Vec operator*(Vec lhs, const Vec& rhs) {
  149. for (Integer i = 0; i < N; i++) lhs.v[i] *= rhs.v[i];
  150. return lhs;
  151. }
  152. friend Vec operator+(Vec lhs, const Vec& rhs) {
  153. for (Integer i = 0; i < N; i++) lhs.v[i] += rhs.v[i];
  154. return lhs;
  155. }
  156. friend Vec operator-(Vec lhs, const Vec& rhs) {
  157. for (Integer i = 0; i < N; i++) lhs.v[i] -= rhs.v[i];
  158. return lhs;
  159. }
  160. friend Vec FMA(Vec a, const Vec& b, const Vec& c) {
  161. for (Integer i = 0; i < N; i++) a.v[i] = a.v[i] * b.v[i] + c.v[i];
  162. return a;
  163. }
  164. // Comparison operators
  165. friend Vec operator< (Vec lhs, const Vec& rhs) {
  166. static const ValueType value_zero = const_zero();
  167. static const ValueType value_one = const_one();
  168. for (Integer i = 0; i < N; i++) lhs.v[i] = (lhs.v[i] < rhs.v[i] ? value_one : value_zero);
  169. return lhs;
  170. }
  171. friend Vec operator<=(Vec lhs, const Vec& rhs) {
  172. static const ValueType value_zero = const_zero();
  173. static const ValueType value_one = const_one();
  174. for (Integer i = 0; i < N; i++) lhs.v[i] = (lhs.v[i] <= rhs.v[i] ? value_one : value_zero);
  175. return lhs;
  176. }
  177. friend Vec operator>=(Vec lhs, const Vec& rhs) {
  178. static const ValueType value_zero = const_zero();
  179. static const ValueType value_one = const_one();
  180. for (Integer i = 0; i < N; i++) lhs.v[i] = (lhs.v[i] >= rhs.v[i] ? value_one : value_zero);
  181. return lhs;
  182. }
  183. friend Vec operator> (Vec lhs, const Vec& rhs) {
  184. static const ValueType value_zero = const_zero();
  185. static const ValueType value_one = const_one();
  186. for (Integer i = 0; i < N; i++) lhs.v[i] = (lhs.v[i] > rhs.v[i] ? value_one : value_zero);
  187. return lhs;
  188. }
  189. friend Vec operator==(Vec lhs, const Vec& rhs) {
  190. static const ValueType value_zero = const_zero();
  191. static const ValueType value_one = const_one();
  192. for (Integer i = 0; i < N; i++) lhs.v[i] = (lhs.v[i] == rhs.v[i] ? value_one : value_zero);
  193. return lhs;
  194. }
  195. friend Vec operator!=(Vec lhs, const Vec& rhs) {
  196. static const ValueType value_zero = const_zero();
  197. static const ValueType value_one = const_one();
  198. for (Integer i = 0; i < N; i++) lhs.v[i] = (lhs.v[i] != rhs.v[i] ? value_one : value_zero);
  199. return lhs;
  200. }
  201. // Bitwise operators
  202. friend Vec operator&(Vec lhs, const Vec& rhs) {
  203. char* vo = (char*)lhs.v;
  204. const char* vi = (const char*)rhs.v;
  205. for (Integer i = 0; i < (Integer)sizeof(ValueType)*N; i++) vo[i] &= vi[i];
  206. return lhs;
  207. }
  208. friend Vec operator^(Vec lhs, const Vec& rhs) {
  209. char* vo = (char*)lhs.v;
  210. const char* vi = (const char*)rhs.v;
  211. for (Integer i = 0; i < (Integer)sizeof(ValueType)*N; i++) vo[i] ^= vi[i];
  212. return lhs;
  213. }
  214. friend Vec operator|(Vec lhs, const Vec& rhs) {
  215. char* vo = (char*)lhs.v;
  216. const char* vi = (const char*)rhs.v;
  217. for (Integer i = 0; i < (Integer)sizeof(ValueType)*N; i++) vo[i] |= vi[i];
  218. return lhs;
  219. }
  220. friend Vec AndNot(Vec lhs, const Vec& rhs) {
  221. return lhs & (~rhs);
  222. }
  223. // Bitshift
  224. friend IntegerVec operator<<(const Vec& lhs, const Integer& rhs) {
  225. IntegerVec r = IntegerVec::LoadAligned(&lhs.v[0]);
  226. for (Integer i = 0; i < N; i++) r.v[i] = r.v[i] << rhs;
  227. return r;
  228. }
  229. // Assignment operators
  230. Vec& operator+=(const Vec& rhs) {
  231. for (Integer i = 0; i < N; i++) v[i] += rhs.v[i];
  232. return *this;
  233. }
  234. Vec& operator-=(const Vec& rhs) {
  235. for (Integer i = 0; i < N; i++) v[i] -= rhs.v[i];
  236. return *this;
  237. }
  238. Vec& operator*=(const Vec& rhs) {
  239. for (Integer i = 0; i < N; i++) v[i] *= rhs.v[i];
  240. return *this;
  241. }
  242. Vec& operator&=(const Vec& rhs) {
  243. char* vo = (char*)this->v;
  244. const char* vi = (const char*)rhs.v;
  245. for (Integer i = 0; i < (Integer)sizeof(ValueType)*N; i++) vo[i] &= vi[i];
  246. return *this;
  247. }
  248. Vec& operator^=(const Vec& rhs) {
  249. char* vo = (char*)this->v;
  250. const char* vi = (const char*)rhs.v;
  251. for (Integer i = 0; i < (Integer)sizeof(ValueType)*N; i++) vo[i] ^= vi[i];
  252. return *this;
  253. }
  254. Vec& operator|=(const Vec& rhs) {
  255. char* vo = (char*)this->v;
  256. const char* vi = (const char*)rhs.v;
  257. for (Integer i = 0; i < (Integer)sizeof(ValueType)*N; i++) vo[i] |= vi[i];
  258. return *this;
  259. }
  260. // Conversion operators
  261. // /
  262. // Other operators
  263. friend Vec max(Vec lhs, const Vec& rhs) {
  264. for (Integer i = 0; i < N; i++) {
  265. if (lhs.v[i] < rhs.v[i]) lhs.v[i] = rhs.v[i];
  266. }
  267. return lhs;
  268. }
  269. friend Vec min(Vec lhs, const Vec& rhs) {
  270. for (Integer i = 0; i < N; i++) {
  271. if (lhs.v[i] > rhs.v[i]) lhs.v[i] = rhs.v[i];
  272. }
  273. return lhs;
  274. }
  275. friend std::ostream& operator<<(std::ostream& os, const Vec& in) {
  276. //for (Integer i = 0; i < (Integer)sizeof(ValueType)*8; i++) os << ((*(uint64_t*)in.v) & (1UL << i) ? '1' : '0');
  277. //os << '\n';
  278. for (Integer i = 0; i < N; i++) os << in.v[i] << ' ';
  279. return os;
  280. }
  281. friend Vec approx_rsqrt(const Vec& x) {
  282. Vec r;
  283. for (int i = 0; i < N; i++) r.v[i] = 1 / sqrt<ValueType>(x.v[i]);
  284. return r;
  285. }
  286. template <class Vec1, class Vec2> friend Vec1 reinterpret(const Vec2& x);
  287. private:
  288. static const ValueType const_zero() {
  289. union {
  290. ValueType value;
  291. unsigned char cvalue[sizeof(ValueType)];
  292. };
  293. for (Integer i = 0; i < (Integer)sizeof(ValueType); i++) cvalue[i] = 0;
  294. return value;
  295. }
  296. static const ValueType const_one() {
  297. union {
  298. ValueType value;
  299. unsigned char cvalue[sizeof(ValueType)];
  300. };
  301. for (Integer i = 0; i < (Integer)sizeof(ValueType); i++) cvalue[i] = ~(unsigned char)0;
  302. return value;
  303. }
  304. ValueType v[N];
  305. };
  306. // Other operators
  // Returns the bytes of `v` re-read as a RetVec (a bit-level cast; no value
  // conversion takes place).  Both types must have the same size, and RetVec
  // must be default-constructible.
  //
  // The copy goes through memcpy instead of dereferencing a casted pointer:
  // that keeps `v` const and avoids the strict-aliasing violation of
  // `*(RetVec*)&v`.
  template <class RetVec, class Vec> RetVec reinterpret(const Vec& v){
    static_assert(sizeof(RetVec) == sizeof(Vec), "reinterpret: source and destination types must have the same size");
    RetVec r;
    std::memcpy(static_cast<void*>(&r), static_cast<const void*>(&v), sizeof(RetVec));
    return r;
  }
  312. template <class RealVec, class IntVec> RealVec ConvertInt2Real(const IntVec& x) {
  313. typedef typename RealVec::ScalarType Real;
  314. typedef typename IntVec::ScalarType Int;
  315. assert(sizeof(RealVec) == sizeof(IntVec));
  316. assert(sizeof(Real) == sizeof(Int));
  317. static constexpr Integer SigBits = TypeTraits<Real>::SigBits;
  318. union {
  319. Int Cint = (1UL << (SigBits - 1)) + ((SigBits + ((1UL<<(sizeof(Real)*8 - SigBits - 2))-1)) << SigBits);
  320. Real Creal;
  321. };
  322. IntVec l(x + IntVec(Cint));
  323. return *(RealVec*)&l - RealVec(Creal);
  324. }
  325. template <class Vec> typename Vec::IntegerVec RoundReal2Int(const Vec& x) {
  326. using IntegerType = typename Vec::IntegerType;
  327. using RealType = typename Vec::RealType;
  328. using IntegerVec = typename Vec::IntegerVec;
  329. using RealVec = typename Vec::RealVec;
  330. static_assert(std::is_same<RealVec,Vec>::value, "RoundReal2Int: expected real input argument!");
  331. static constexpr Integer SigBits = TypeTraits<RealType>::SigBits;
  332. union {
  333. IntegerType Cint = (1UL << (SigBits - 1)) + ((SigBits + ((1UL<<(sizeof(RealType)*8 - SigBits - 2))-1)) << SigBits);
  334. RealType Creal;
  335. };
  336. RealVec d = x + RealVec(Creal);
  337. return reinterpret<IntegerVec>(d) - IntegerVec(Cint);
  338. }
  339. template <class Vec> Vec RoundReal2Real(const Vec& x) {
  340. typedef typename Vec::ScalarType Real;
  341. static constexpr Integer SigBits = TypeTraits<Real>::SigBits;
  342. union {
  343. int64_t Cint = (1UL << (SigBits - 1)) + ((SigBits + ((1UL<<(sizeof(Real)*8 - SigBits - 2))-1)) << SigBits);
  344. Real Creal;
  345. };
  346. Vec Vreal(Creal);
  347. return (x + Vreal) - Vreal;
  348. }
  349. template <class Vec> void sincos_intrin(Vec& sinx, Vec& cosx, const Vec& x) {
  350. constexpr Integer ORDER = 13;
  351. // ORDER ERROR
  352. // 1 8.81e-02
  353. // 3 2.45e-03
  354. // 5 3.63e-05
  355. // 7 3.11e-07
  356. // 9 1.75e-09
  357. // 11 6.93e-12
  358. // 13 2.09e-14
  359. // 15 6.66e-16
  360. // 17 6.66e-16
  361. using Real = typename Vec::ScalarType;
  362. static constexpr Integer SigBits = TypeTraits<Real>::SigBits;
  363. static constexpr Real coeff3 = -1/(((Real)2)*3);
  364. static constexpr Real coeff5 = 1/(((Real)2)*3*4*5);
  365. static constexpr Real coeff7 = -1/(((Real)2)*3*4*5*6*7);
  366. static constexpr Real coeff9 = 1/(((Real)2)*3*4*5*6*7*8*9);
  367. static constexpr Real coeff11 = -1/(((Real)2)*3*4*5*6*7*8*9*10*11);
  368. static constexpr Real coeff13 = 1/(((Real)2)*3*4*5*6*7*8*9*10*11*12*13);
  369. static constexpr Real coeff15 = -1/(((Real)2)*3*4*5*6*7*8*9*10*11*12*13*14*15);
  370. static constexpr Real coeff17 = 1/(((Real)2)*3*4*5*6*7*8*9*10*11*12*13*14*15*16*17);
  371. static constexpr Real coeff19 = -1/(((Real)2)*3*4*5*6*7*8*9*10*11*12*13*14*15*16*17*18*19);
  372. static constexpr Real x0 = (Real)1.570796326794896619231321691639l;
  373. static constexpr Real invx0 = 1 / x0;
  374. Vec x_ = RoundReal2Real(x * invx0); // 4.5 - cycles
  375. Vec x1 = x - x_ * x0; // 2 - cycles
  376. Vec x2, x3, x5, x7, x9, x11, x13, x15, x17, x19;
  377. Vec s1 = x1;
  378. if (ORDER >= 3) { // 5 - cycles
  379. x2 = x1 * x1;
  380. x3 = x1 * x2;
  381. s1 += x3 * coeff3;
  382. }
  383. if (ORDER >= 5) { // 3 - cycles
  384. x5 = x3 * x2;
  385. s1 += x5 * coeff5;
  386. }
  387. if (ORDER >= 7) {
  388. x7 = x5 * x2;
  389. s1 += x7 * coeff7;
  390. }
  391. if (ORDER >= 9) {
  392. x9 = x7 * x2;
  393. s1 += x9 * coeff9;
  394. }
  395. if (ORDER >= 11) {
  396. x11 = x9 * x2;
  397. s1 += x11 * coeff11;
  398. }
  399. if (ORDER >= 13) {
  400. x13 = x11 * x2;
  401. s1 += x13 * coeff13;
  402. }
  403. if (ORDER >= 15) {
  404. x15 = x13 * x2;
  405. s1 += x15 * coeff15;
  406. }
  407. if (ORDER >= 17) {
  408. x17 = x15 * x2;
  409. s1 += x17 * coeff17;
  410. }
  411. if (ORDER >= 19) {
  412. x19 = x17 * x2;
  413. s1 += x19 * coeff19;
  414. }
  415. Vec cos_squared = (Real)1.0 - s1 * s1;
  416. Vec inv_cos = approx_rsqrt(cos_squared); // 1.5 - cycles
  417. if (ORDER < 5) {
  418. } else if (ORDER < 9) {
  419. inv_cos *= ((3.0) - cos_squared * inv_cos * inv_cos) * 0.5; // 7 - cycles
  420. } else if (ORDER < 15) {
  421. inv_cos *= ((3.0) - cos_squared * inv_cos * inv_cos); // 7 - cycles
  422. inv_cos *= ((3.0 * pow<pow<0>(3)*3-1>(2.0)) - cos_squared * inv_cos * inv_cos) * (pow<(pow<0>(3)*3-1)*3/2+1>(0.5)); // 8 - cycles
  423. } else {
  424. inv_cos *= ((3.0) - cos_squared * inv_cos * inv_cos); // 7 - cycles
  425. inv_cos *= ((3.0 * pow<pow<0>(3)*3-1>(2.0)) - cos_squared * inv_cos * inv_cos); // 7 - cycles
  426. inv_cos *= ((3.0 * pow<pow<1>(3)*3-1>(2.0)) - cos_squared * inv_cos * inv_cos) * (pow<(pow<1>(3)*3-1)*3/2+1>(0.5)); // 8 - cycles
  427. }
  428. Vec c1 = cos_squared * inv_cos; // 1 - cycle
  429. union {
  430. int64_t int_zero = 0 + (1UL << (SigBits - 1)) + ((SigBits + ((1UL<<(sizeof(Real)*8 - SigBits - 2))-1)) << SigBits);
  431. Real real_zero;
  432. };
  433. union {
  434. int64_t int_one = 1 + (1UL << (SigBits - 1)) + ((SigBits + ((1UL<<(sizeof(Real)*8 - SigBits - 2))-1)) << SigBits);
  435. Real real_one;
  436. };
  437. union {
  438. int64_t int_two = 2 + (1UL << (SigBits - 1)) + ((SigBits + ((1UL<<(sizeof(Real)*8 - SigBits - 2))-1)) << SigBits);
  439. Real real_two;
  440. };
  441. Vec x_offset(real_zero);
  442. auto xAnd1 = (((x_+x_offset) & Vec(real_one)) == x_offset);
  443. auto xAnd2 = (((x_+x_offset) & Vec(real_two)) == x_offset);
  444. Vec s2 = AndNot( c1,xAnd1) | (s1 & xAnd1);
  445. Vec c2 = AndNot(-s1,xAnd1) | (c1 & xAnd1);
  446. Vec s3 = AndNot(-s2,xAnd2) | (s2 & xAnd2);
  447. Vec c3 = AndNot(-c2,xAnd2) | (c2 & xAnd2);
  448. sinx = s3;
  449. cosx = c3;
  450. }
  451. template <class Vec> void exp_intrin(Vec& expx, const Vec& x) {
  452. constexpr Integer ORDER = 10;
  453. using IntegerType = typename Vec::IntegerType;
  454. using RealType = typename Vec::RealType;
  455. using IntegerVec = typename Vec::IntegerVec;
  456. using RealVec = typename Vec::RealVec;
  457. static_assert(std::is_same<Vec,RealVec>::value, "exp_intrin: expected a real argument");
  458. using Real = typename RealVec::ScalarType;
  459. static constexpr Integer SigBits = TypeTraits<Real>::SigBits;
  460. static constexpr Real coeff2 = 1/(((Real)2));
  461. static constexpr Real coeff3 = 1/(((Real)2)*3);
  462. static constexpr Real coeff4 = 1/(((Real)2)*3*4);
  463. static constexpr Real coeff5 = 1/(((Real)2)*3*4*5);
  464. static constexpr Real coeff6 = 1/(((Real)2)*3*4*5*6);
  465. static constexpr Real coeff7 = 1/(((Real)2)*3*4*5*6*7);
  466. static constexpr Real coeff8 = 1/(((Real)2)*3*4*5*6*7*8);
  467. static constexpr Real coeff9 = 1/(((Real)2)*3*4*5*6*7*8*9);
  468. static constexpr Real coeff10 = 1/(((Real)2)*3*4*5*6*7*8*9*10);
  469. static constexpr Real x0 = (Real)0.693147180559945309417232121458l; // ln(2)
  470. static constexpr Real invx0 = 1 / x0;
  471. RealVec x_ = RoundReal2Real(x * invx0); // 4.5 - cycles
  472. IntegerVec int_x_ = RoundReal2Int<RealVec>(x_);
  473. RealVec x1 = x - x_ * x0; // 2 - cycles
  474. RealVec x2, x3, x4, x5, x6, x7, x8, x9, x10;
  475. RealVec e1 = 1.0 + x1;
  476. if (ORDER >= 2) {
  477. x2 = x1 * x1;
  478. e1 += x2 * coeff2;
  479. }
  480. if (ORDER >= 3) {
  481. x3 = x2 * x1;
  482. e1 += x3 * coeff3;
  483. }
  484. if (ORDER >= 4) {
  485. x4 = x2 * x2;
  486. e1 += x4 * coeff4;
  487. }
  488. if (ORDER >= 5) {
  489. x5 = x3 * x2;
  490. e1 += x5 * coeff5;
  491. }
  492. if (ORDER >= 6) {
  493. x6 = x3 * x3;
  494. e1 += x6 * coeff6;
  495. }
  496. if (ORDER >= 7) {
  497. x7 = x4 * x3;
  498. e1 += x7 * coeff7;
  499. }
  500. if (ORDER >= 8) {
  501. x8 = x4 * x4;
  502. e1 += x8 * coeff8;
  503. }
  504. if (ORDER >= 9) {
  505. x9 = x5 * x4;
  506. e1 += x9 * coeff9;
  507. }
  508. if (ORDER >= 10) {
  509. x10 = x5 * x5;
  510. e1 += x10 * coeff10;
  511. }
  512. RealVec e2;
  513. { // set e2 = 2 ^ x_
  514. union {
  515. RealType real_one = 1.0;
  516. IntegerType int_one;
  517. };
  518. //__m256i int_e2 = _mm256_add_epi64(
  519. // _mm256_set1_epi64x(int_one),
  520. // _mm256_slli_epi64(
  521. // _mm256_load_si256((__m256i const*)&int_x_),
  522. // SigBits
  523. // )
  524. // ); // int_e2 = int_one + (int_x_ << SigBits);
  525. IntegerVec int_e2 = IntegerVec(int_one) + (int_x_ << SigBits);
  526. // Handle underflow
  527. static constexpr IntegerType max_exp = -(IntegerType)(1UL<<((sizeof(Real)*8-SigBits-2)));
  528. int_e2 &= (int_x_ > IntegerVec(max_exp));
  529. e2 = RealVec::LoadAligned((RealType*)&int_e2);
  530. }
  531. expx = e1 * e2;
  532. }
  533. #ifdef __AVX__
  534. template <> class alignas(sizeof(double)*4) Vec<double,4> {
  535. typedef __m256d VecType;
  536. typedef double ValueType;
  537. static constexpr Integer N = 4;
  538. public:
  539. typedef typename GetType<DataType::Integer,TypeTraits<ValueType>::Size>::ValueType IntegerType;
  540. typedef typename GetType<DataType::Real,TypeTraits<ValueType>::Size>::ValueType RealType;
  541. typedef Vec<IntegerType,N> IntegerVec;
  542. typedef Vec<RealType,N> RealVec;
  543. typedef ValueType ScalarType;
  544. static constexpr Integer Size() {
  545. return N;
  546. }
  547. static Vec Zero() {
  548. Vec r;
  549. r.v = _mm256_setzero_pd();
  550. return r;
  551. }
  552. static Vec Load1(ValueType const* p) {
  553. Vec r;
  554. r.v = _mm256_broadcast_sd(p);
  555. return r;
  556. }
  557. static Vec Load(ValueType const* p) {
  558. Vec r;
  559. r.v = _mm256_loadu_pd(p);
  560. return r;
  561. }
  562. static Vec LoadAligned(ValueType const* p) {
  563. Vec r;
  564. r.v = _mm256_load_pd(p);
  565. return r;
  566. }
  567. Vec() = default;
  568. Vec(const ValueType& a) {
  569. v = _mm256_set1_pd(a);
  570. }
  571. void Store(ValueType* p) const {
  572. _mm256_storeu_pd(p, v);
  573. }
  574. void StoreAligned(ValueType* p) const {
  575. _mm256_store_pd(p, v);
  576. }
  577. // Bitwise NOT
  578. Vec operator~() const {
  579. Vec r;
  580. static constexpr ScalarType Creal = -1.0;
  581. r.v = _mm256_xor_pd(v, _mm256_set1_pd(Creal));
  582. return r;
  583. }
  584. // Unary plus and minus
  585. Vec operator+() const {
  586. return *this;
  587. }
  588. Vec operator-() const {
  589. return Zero() - (*this);
  590. }
  591. // C-style cast
  592. //template <class RetValueType> explicit operator Vec<RetValueType,N>() const {
  593. //}
  594. // Arithmetic operators
  595. friend Vec operator*(Vec lhs, const Vec& rhs) {
  596. lhs.v = _mm256_mul_pd(lhs.v, rhs.v);
  597. return lhs;
  598. }
  599. friend Vec operator+(Vec lhs, const Vec& rhs) {
  600. lhs.v = _mm256_add_pd(lhs.v, rhs.v);
  601. return lhs;
  602. }
  603. friend Vec operator-(Vec lhs, const Vec& rhs) {
  604. lhs.v = _mm256_sub_pd(lhs.v, rhs.v);
  605. return lhs;
  606. }
  607. friend Vec FMA(Vec a, const Vec& b, const Vec& c) {
  608. #ifdef __FMA__
  609. a.v = _mm256_fmadd_pd(a.v, b.v, c.v);
  610. #else
  611. a.v = _mm256_add_pd(_mm256_mul_pd(a.v, b.v), c.v);
  612. #endif
  613. return a;
  614. }
  615. // Comparison operators
  616. friend Vec operator< (Vec lhs, const Vec& rhs) {
  617. lhs.v = _mm256_cmp_pd(lhs.v, rhs.v, _CMP_LT_OS);
  618. return lhs;
  619. }
  620. friend Vec operator<=(Vec lhs, const Vec& rhs) {
  621. lhs.v = _mm256_cmp_pd(lhs.v, rhs.v, _CMP_LE_OS);
  622. return lhs;
  623. }
  624. friend Vec operator>=(Vec lhs, const Vec& rhs) {
  625. lhs.v = _mm256_cmp_pd(lhs.v, rhs.v, _CMP_GE_OS);
  626. return lhs;
  627. }
  628. friend Vec operator> (Vec lhs, const Vec& rhs) {
  629. lhs.v = _mm256_cmp_pd(lhs.v, rhs.v, _CMP_GT_OS);
  630. return lhs;
  631. }
  632. friend Vec operator==(Vec lhs, const Vec& rhs) {
  633. lhs.v = _mm256_cmp_pd(lhs.v, rhs.v, _CMP_EQ_OS);
  634. return lhs;
  635. }
  636. friend Vec operator!=(Vec lhs, const Vec& rhs) {
  637. lhs.v = _mm256_cmp_pd(lhs.v, rhs.v, _CMP_NEQ_OS);
  638. return lhs;
  639. }
  640. // Bitwise operators
  641. friend Vec operator&(Vec lhs, const Vec& rhs) {
  642. lhs.v = _mm256_and_pd(lhs.v, rhs.v);
  643. return lhs;
  644. }
  645. friend Vec operator^(Vec lhs, const Vec& rhs) {
  646. lhs.v = _mm256_xor_pd(lhs.v, rhs.v);
  647. return lhs;
  648. }
  649. friend Vec operator|(Vec lhs, const Vec& rhs) {
  650. lhs.v = _mm256_or_pd(lhs.v, rhs.v);
  651. return lhs;
  652. }
  653. friend Vec AndNot(Vec lhs, const Vec& rhs) {
  654. lhs.v = _mm256_andnot_pd(rhs.v, lhs.v);
  655. return lhs;
  656. }
  657. // Assignment operators
  658. Vec& operator*=(const Vec& rhs) {
  659. v = _mm256_mul_pd(v, rhs.v);
  660. return *this;
  661. }
  662. Vec& operator+=(const Vec& rhs) {
  663. v = _mm256_add_pd(v, rhs.v);
  664. return *this;
  665. }
  666. Vec& operator-=(const Vec& rhs) {
  667. v = _mm256_sub_pd(v, rhs.v);
  668. return *this;
  669. }
  670. Vec& operator&=(const Vec& rhs) {
  671. v = _mm256_and_pd(v, rhs.v);
  672. return *this;
  673. }
  674. Vec& operator^=(const Vec& rhs) {
  675. v = _mm256_xor_pd(v, rhs.v);
  676. return *this;
  677. }
  678. Vec& operator|=(const Vec& rhs) {
  679. v = _mm256_or_pd(v, rhs.v);
  680. return *this;
  681. }
  682. // Other operators
  683. friend Vec max(Vec lhs, const Vec& rhs) {
  684. lhs.v = _mm256_max_pd(lhs.v, rhs.v);
  685. return lhs;
  686. }
  687. friend Vec min(Vec lhs, const Vec& rhs) {
  688. lhs.v = _mm256_min_pd(lhs.v, rhs.v);
  689. return lhs;
  690. }
  691. friend std::ostream& operator<<(std::ostream& os, const Vec& in) {
  692. union {
  693. VecType vec;
  694. ValueType val[N];
  695. };
  696. vec = in.v;
  697. for (Integer i = 0; i < N; i++) os << val[i] << ' ';
  698. return os;
  699. }
  700. friend Vec approx_rsqrt(const Vec& x) {
  701. Vec r;
  702. r.v = _mm256_cvtps_pd(_mm_rsqrt_ps(_mm256_cvtpd_ps(x.v)));
  703. return r;
  704. }
  705. template <class Vec1, class Vec2> friend Vec1 reinterpret(const Vec2& x);
  706. template <class Vec> friend Vec RoundReal2Real(const Vec& x);
  707. template <class Vec> friend void sincos_intrin(Vec& sinx, Vec& cosx, const Vec& x);
  708. template <class Vec> friend void exp_intrin(Vec& expx, const Vec& x);
  709. private:
  710. VecType v;
  711. };
  712. template <> inline Vec<int64_t,4> reinterpret<Vec<int64_t,4>,Vec<double,4>>(const Vec<double,4>& x){
  713. union {
  714. Vec<int64_t,4> r;
  715. __m256i y;
  716. };
  717. y = _mm256_castpd_si256(x.v);
  718. return r;
  719. }
  720. template <> inline Vec<double,4> RoundReal2Real(const Vec<double,4>& x) {
  721. Vec<double,4> r;
  722. r.v = _mm256_round_pd(x.v,_MM_FROUND_TO_NEAREST_INT |_MM_FROUND_NO_EXC);
  723. return r;
  724. }
  725. #ifdef SCTL_HAVE_SVML
  726. template <> inline void sincos_intrin(Vec<double,4>& sinx, Vec<double,4>& cosx, const Vec<double,4>& x) {
  727. sinx.v = _mm256_sin_pd(x.v);
  728. cosx.v = _mm256_cos_pd(x.v);
  729. }
  730. template <> inline void exp_intrin(Vec<double,4>& expx, const Vec<double,4>& x) {
  731. expx.v = _mm256_exp_pd(x.v);
  732. }
  733. #endif
  734. #endif
  735. #ifdef __AVX512F__
  736. template <> class alignas(sizeof(double)*8) Vec<double,8> {
  737. typedef __m512d VecType;
  738. typedef double ValueType;
  739. static constexpr Integer N = 8;
  740. public:
  741. typedef typename GetType<DataType::Integer,TypeTraits<ValueType>::Size>::ValueType IntegerType;
  742. typedef typename GetType<DataType::Real,TypeTraits<ValueType>::Size>::ValueType RealType;
  743. typedef Vec<IntegerType,N> IntegerVec;
  744. typedef Vec<RealType,N> RealVec;
  745. typedef ValueType ScalarType;
  746. static constexpr Integer Size() {
  747. return N;
  748. }
  749. static Vec Zero() {
  750. Vec r;
  751. r.v = _mm512_setzero_pd();
  752. return r;
  753. }
  754. static Vec Load1(ValueType const* p) {
  755. Vec r;
  756. // TODO: different from _m256d, could make it faster?
  757. // r.v = _mm512_broadcast_f64x4(_mm256_broadcast_sd(p));
  758. r.v = _mm512_set1_pd(*p);
  759. return r;
  760. }
  761. static Vec Load(ValueType const* p) {
  762. Vec r;
  763. r.v = _mm512_loadu_pd(p);
  764. return r;
  765. }
  766. static Vec LoadAligned(ValueType const* p) {
  767. Vec r;
  768. r.v = _mm512_load_pd(p);
  769. return r;
  770. }
  771. Vec() = default;
  772. Vec(const ValueType& a) {
  773. v = _mm512_set1_pd(a);
  774. }
  775. Vec(const __mmask8& a) = delete; // disallow implicit conversions
  776. void Store(ValueType* p) const {
  777. _mm512_storeu_pd(p, v);
  778. }
  779. void StoreAligned(ValueType* p) const {
  780. _mm512_store_pd(p, v);
  781. }
  782. // Bitwise NOT
  783. Vec operator~() const {
  784. Vec r;
  785. static constexpr ScalarType Creal = -1.0;
  786. r.v = _mm512_xor_pd(v, _mm512_set1_pd(Creal));
  787. return r;
  788. }
  789. // Unary plus and minus
  790. Vec operator+() const {
  791. return *this;
  792. }
  793. Vec operator-() const {
  794. return Zero() - (*this);
  795. }
  796. // C-style cast
  797. //template <class RetValueType> explicit operator Vec<RetValueType,N>() const {
  798. //}
  799. // Arithmetic operators
  800. friend Vec operator*(Vec lhs, const Vec& rhs) {
  801. lhs.v = _mm512_mul_pd(lhs.v, rhs.v);
  802. return lhs;
  803. }
  804. friend Vec operator+(Vec lhs, const Vec& rhs) {
  805. lhs.v = _mm512_add_pd(lhs.v, rhs.v);
  806. return lhs;
  807. }
  808. friend Vec operator-(Vec lhs, const Vec& rhs) {
  809. lhs.v = _mm512_sub_pd(lhs.v, rhs.v);
  810. return lhs;
  811. }
  812. friend Vec FMA(Vec a, const Vec& b, const Vec& c) {
  813. a.v = _mm512_fmadd_pd(a.v, b.v, c.v);
  814. //a.v = _mm512_add_pd(_mm512_mul_pd(a.v, b.v), c.v);
  815. return a;
  816. }
  817. // Comparison operators
  818. //friend Vec operator< (Vec lhs, const Vec& rhs) {
  819. // lhs.v = _mm512_castsi512_pd(_mm512_movm_epi64(_mm512_cmp_pd_mask(lhs.v, rhs.v, _CMP_LT_OS)));
  820. // return lhs;
  821. //}
  822. //friend Vec operator<=(Vec lhs, const Vec& rhs) {
  823. // lhs.v = _mm512_castsi512_pd(_mm512_movm_epi64(_mm512_cmp_pd_mask(lhs.v, rhs.v, _CMP_LE_OS)));
  824. // return lhs;
  825. //}
  826. //friend Vec operator>=(Vec lhs, const Vec& rhs) {
  827. // lhs.v = _mm512_castsi512_pd(_mm512_movm_epi64(_mm512_cmp_pd_mask(lhs.v, rhs.v, _CMP_GE_OS)));
  828. // return lhs;
  829. //}
  830. //friend Vec operator> (Vec lhs, const Vec& rhs) {
  831. // lhs.v = _mm512_castsi512_pd(_mm512_movm_epi64(_mm512_cmp_pd_mask(lhs.v, rhs.v, _CMP_GT_OS)));
  832. // return lhs;
  833. //}
  834. //friend Vec operator==(Vec lhs, const Vec& rhs) {
  835. // lhs.v = _mm512_castsi512_pd(_mm512_movm_epi64(_mm512_cmp_pd_mask(lhs.v, rhs.v, _CMP_EQ_OS)));
  836. // return lhs;
  837. //}
  838. //friend Vec operator!=(Vec lhs, const Vec& rhs) {
  839. // lhs.v = _mm512_castsi512_pd(_mm512_movm_epi64(_mm512_cmp_pd_mask(lhs.v, rhs.v, _CMP_NEQ_OS)));
  840. // return lhs;
  841. //}
  842. friend __mmask8 operator< (Vec lhs, const Vec& rhs) {
  843. return _mm512_cmp_pd_mask(lhs.v, rhs.v, _CMP_LT_OS);
  844. }
  845. friend __mmask8 operator<=(Vec lhs, const Vec& rhs) {
  846. return _mm512_cmp_pd_mask(lhs.v, rhs.v, _CMP_LE_OS);
  847. }
  848. friend __mmask8 operator>=(Vec lhs, const Vec& rhs) {
  849. return _mm512_cmp_pd_mask(lhs.v, rhs.v, _CMP_GE_OS);
  850. }
  851. friend __mmask8 operator> (Vec lhs, const Vec& rhs) {
  852. return _mm512_cmp_pd_mask(lhs.v, rhs.v, _CMP_GT_OS);
  853. }
  854. friend __mmask8 operator==(Vec lhs, const Vec& rhs) {
  855. return _mm512_cmp_pd_mask(lhs.v, rhs.v, _CMP_EQ_OS);
  856. }
  857. friend __mmask8 operator!=(Vec lhs, const Vec& rhs) {
  858. return _mm512_cmp_pd_mask(lhs.v, rhs.v, _CMP_NEQ_OS);
  859. }
  860. // Bitwise operators
  861. friend Vec operator&(Vec lhs, const Vec& rhs) {
  862. lhs.v = _mm512_and_pd(lhs.v, rhs.v);
  863. return lhs;
  864. }
  865. friend Vec operator^(Vec lhs, const Vec& rhs) {
  866. lhs.v = _mm512_xor_pd(lhs.v, rhs.v);
  867. return lhs;
  868. }
  869. friend Vec operator|(Vec lhs, const Vec& rhs) {
  870. lhs.v = _mm512_or_pd(lhs.v, rhs.v);
  871. return lhs;
  872. }
  873. friend Vec AndNot(Vec lhs, const Vec& rhs) {
  874. lhs.v = _mm512_andnot_pd(rhs.v, lhs.v);
  875. return lhs;
  876. }
  877. friend Vec operator&(Vec lhs, const __mmask8& rhs) {
  878. lhs.v = _mm512_maskz_mov_pd(rhs, lhs.v);
  879. return lhs;
  880. }
  881. friend Vec AndNot(Vec lhs, const __mmask8& rhs) {
  882. lhs.v = _mm512_mask_mov_pd(lhs.v, rhs, _mm512_setzero_pd());
  883. return lhs;
  884. }
  885. // Assignment operators
  886. Vec& operator*=(const Vec& rhs) {
  887. v = _mm512_mul_pd(v, rhs.v);
  888. return *this;
  889. }
  890. Vec& operator+=(const Vec& rhs) {
  891. v = _mm512_add_pd(v, rhs.v);
  892. return *this;
  893. }
  894. Vec& operator-=(const Vec& rhs) {
  895. v = _mm512_sub_pd(v, rhs.v);
  896. return *this;
  897. }
  898. Vec& operator&=(const Vec& rhs) {
  899. v = _mm512_and_pd(v, rhs.v);
  900. return *this;
  901. }
  902. Vec& operator^=(const Vec& rhs) {
  903. v = _mm512_xor_pd(v, rhs.v);
  904. return *this;
  905. }
  906. Vec& operator|=(const Vec& rhs) {
  907. v = _mm512_or_pd(v, rhs.v);
  908. return *this;
  909. }
  910. Vec& operator&=(const __mmask8& rhs) {
  911. v = _mm512_maskz_mov_pd(rhs, v);
  912. return *this;
  913. }
  914. // Other operators
  915. friend Vec max(Vec lhs, const Vec& rhs) {
  916. lhs.v = _mm512_max_pd(lhs.v, rhs.v);
  917. return lhs;
  918. }
  919. friend Vec min(Vec lhs, const Vec& rhs) {
  920. lhs.v = _mm512_min_pd(lhs.v, rhs.v);
  921. return lhs;
  922. }
  923. friend std::ostream& operator<<(std::ostream& os, const Vec& in) {
  924. union {
  925. VecType vec;
  926. ValueType val[N];
  927. };
  928. vec = in.v;
  929. for (Integer i = 0; i < N; i++) os << val[i] << ' ';
  930. return os;
  931. }
  932. friend Vec approx_rsqrt(const Vec& x) {
  933. Vec r;
  934. r.v = _mm512_cvtps_pd(_mm256_rsqrt_ps(_mm512_cvtpd_ps(x.v)));
  935. return r;
  936. }
  937. template <class Vec1, class Vec2> friend Vec1 reinterpret(const Vec2& x);
  938. template <class Vec> friend Vec RoundReal2Real(const Vec& x);
  939. template <class Vec> friend void sincos_intrin(Vec& sinx, Vec& cosx, const Vec& x);
  940. template <class Vec> friend void exp_intrin(Vec& expx, const Vec& x);
  941. private:
  942. VecType v;
  943. };
  944. template <> inline Vec<int64_t,8> reinterpret<Vec<int64_t,8>,Vec<double,8>>(const Vec<double,8>& x){
  945. union {
  946. Vec<int64_t,8> r;
  947. __m512i y;
  948. };
  949. y = _mm512_castpd_si512(x.v);
  950. return r;
  951. }
  952. template <> inline Vec<double,8> RoundReal2Real(const Vec<double,8>& x) {
  953. Vec<double,8> r;
  954. // TODO: need double check
  955. r.v = _mm512_roundscale_pd(x.v,_MM_FROUND_TO_NEAREST_INT |_MM_FROUND_NO_EXC);
  956. return r;
  957. }
  958. #ifdef SCTL_HAVE_SVML
  959. template <> inline void sincos_intrin(Vec<double,8>& sinx, Vec<double,8>& cosx, const Vec<double,8>& x) {
  960. sinx.v = _mm512_sin_pd(x.v);
  961. cosx.v = _mm512_cos_pd(x.v);
  962. }
  963. template <> inline void exp_intrin(Vec<double,8>& expx, const Vec<double,8>& x) {
  964. expx.v = _mm512_exp_pd(x.v);
  965. }
  966. #endif
  967. #endif
  968. }
  969. #endif //_SCTL_VEC_WRAPPER_HPP_