kernel_functions.hpp 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515
  1. #ifndef _SCTL_KERNEL_FUNCTIONS_HPP_
  2. #define _SCTL_KERNEL_FUNCTIONS_HPP_
  3. #include <sctl/common.hpp>
  4. #include SCTL_INCLUDE(vec.hpp)
  5. #include SCTL_INCLUDE(mem_mgr.hpp)
  6. namespace SCTL_NAMESPACE {
  7. template <class ValueType> class Matrix;
  8. template <class ValueType> class Vector;
// Dispatch shim that forwards to the user micro-kernel's uKerMatrix.
// This primary template handles kernels that take a source-normal argument
// (N_DIM > 0): the normal array n is passed through to uKerMatrix.
template <class uKernel, Integer KDIM0, Integer KDIM1, Integer DIM, Integer N_DIM> struct uKerHelper {
  // Evaluate the KDIM0 x KDIM1 kernel matrix u for displacement r and normal n.
  template <Integer digits, class VecType> static void MatEval(VecType (&u)[KDIM0][KDIM1], const VecType (&r)[DIM], const VecType (&n)[N_DIM], const void* ctx_ptr) {
    uKernel::template uKerMatrix<digits>(u, r, n, ctx_ptr);
  }
};
// Specialization for kernels without a normal-vector argument (N_DIM == 0):
// the normal parameter n is accepted (any type) but not forwarded.
template <class uKernel, Integer KDIM0, Integer KDIM1, Integer DIM> struct uKerHelper<uKernel,KDIM0,KDIM1,DIM,0> {
  template <Integer digits, class VecType, class NormalType> static void MatEval(VecType (&u)[KDIM0][KDIM1], const VecType (&r)[DIM], const NormalType& n, const void* ctx_ptr) {
    uKernel::template uKerMatrix<digits>(u, r, ctx_ptr);
  }
};
// GenericKernel wraps a user-defined micro-kernel `uKernel` and provides
// SIMD-vectorized (optionally OpenMP-parallel) drivers:
//  - Eval():         direct summation  v_trg += scale * K(xt, xs, ns) * v_src
//  - KernelMatrix(): assembly of the dense (Ns*KDIM0) x (Nt*KDIM1) matrix.
// All kernel dimensions (DIM, KDIM0, KDIM1, N_DIM) are deduced at compile
// time from the signature of uKernel::uKerMatrix.
template <class uKernel> class GenericKernel : public uKernel {
  // Signature probes: deduce the coordinate dimension D, source dimension K0
  // and target dimension K1 from a pointer to uKernel::uKerMatrix.
  template <class VecType, Integer K0, Integer K1, Integer D, class ...T> static constexpr Integer get_DIM (void (*uKer)(VecType (&u)[K0][K1], const VecType (&r)[D], T... args)) { return D; }
  template <class VecType, Integer K0, Integer K1, Integer D, class ...T> static constexpr Integer get_KDIM0(void (*uKer)(VecType (&u)[K0][K1], const VecType (&r)[D], T... args)) { return K0; }
  template <class VecType, Integer K0, Integer K1, Integer D, class ...T> static constexpr Integer get_KDIM1(void (*uKer)(VecType (&u)[K0][K1], const VecType (&r)[D], T... args)) { return K1; }

  static constexpr Integer DIM = get_DIM (uKernel::template uKerMatrix<0,Vec<double,1>>);   // coordinate dimension
  static constexpr Integer KDIM0 = get_KDIM0(uKernel::template uKerMatrix<0,Vec<double,1>>); // source density dimension
  static constexpr Integer KDIM1 = get_KDIM1(uKernel::template uKerMatrix<0,Vec<double,1>>); // target potential dimension

  // argsize<idx>(f): sizeof() of the idx-th parameter type of function f.
  template <Integer cnt> static constexpr Integer argsize_helper() { return 0; }
  template <Integer cnt, class T, class ...T1> static constexpr Integer argsize_helper() { return (cnt == 0 ? sizeof(T) : 0) + argsize_helper<cnt-1, T1...>(); }
  template <Integer idx, class ...T1> static constexpr Integer argsize(void (uKer)(T1... args)) { return argsize_helper<idx, T1...>(); }

  // argcount(f): number of parameters of function f.
  template <Integer cnt> static constexpr Integer argcount_helper() { return cnt; }
  template <Integer cnt, class T, class ...T1> static constexpr Integer argcount_helper() { return argcount_helper<cnt+1, T1...>(); }
  template <class ...T1> static constexpr Integer argcount(void (uKer)(T1... args)) { return argcount_helper<0, T1...>(); }

  static constexpr Integer ARGCNT = argcount(uKernel::template uKerMatrix<0,Vec<double,1>>);
  // If uKerMatrix has a 4th parameter (u, r, n, ctx) then parameter index 2 is
  // the normal array; its byte size divided by the vector-type size gives N_DIM.
  static constexpr Integer N_DIM = (ARGCNT > 3 ? argsize<2>(uKernel::template uKerMatrix<0,Vec<double,1>>)/sizeof(Vec<double,1>) : 0);
  static constexpr Integer N_DIM_ = (N_DIM?N_DIM:1); // non-zero

  public:

  GenericKernel() : ctx_ptr(nullptr) {}

  // Dimension of source/target coordinate vectors.
  static constexpr Integer CoordDim() {
    return DIM;
  }
  // Dimension of the source normal (0 if the kernel takes no normal argument).
  static constexpr Integer NormalDim() {
    return N_DIM;
  }
  // Density components per source point.
  static constexpr Integer SrcDim() {
    return KDIM0;
  }
  // Potential components per target point.
  static constexpr Integer TrgDim() {
    return KDIM1;
  }
  // Opaque user context forwarded to the micro-kernel.
  const void* GetCtxPtr() const {
    return ctx_ptr;
  }

  // Type-erased entry point: converts the runtime accuracy parameter `digits`
  // into a compile-time template argument (binary search over the value) and
  // invokes the member Eval() on the GenericKernel pointed to by `self`.
  // digits == -1 or digits >= 16 selects full precision for Real.
  template <class Real, bool enable_openmp> static void Eval(Vector<Real>& v_trg, const Vector<Real>& r_trg, const Vector<Real>& r_src, const Vector<Real>& n_src, const Vector<Real>& v_src, Integer digits, ConstIterator<char> self) {
    if (digits < 8) {
      if (digits < 4) {
        if (digits == -1) ((ConstIterator<GenericKernel<uKernel>>)self)->template Eval<Real, enable_openmp,-1>(v_trg, r_trg, r_src, n_src, v_src);
        if (digits == 0) ((ConstIterator<GenericKernel<uKernel>>)self)->template Eval<Real, enable_openmp, 0>(v_trg, r_trg, r_src, n_src, v_src);
        if (digits == 1) ((ConstIterator<GenericKernel<uKernel>>)self)->template Eval<Real, enable_openmp, 1>(v_trg, r_trg, r_src, n_src, v_src);
        if (digits == 2) ((ConstIterator<GenericKernel<uKernel>>)self)->template Eval<Real, enable_openmp, 2>(v_trg, r_trg, r_src, n_src, v_src);
        if (digits == 3) ((ConstIterator<GenericKernel<uKernel>>)self)->template Eval<Real, enable_openmp, 3>(v_trg, r_trg, r_src, n_src, v_src);
      } else {
        if (digits == 7) ((ConstIterator<GenericKernel<uKernel>>)self)->template Eval<Real, enable_openmp, 7>(v_trg, r_trg, r_src, n_src, v_src);
        if (digits == 6) ((ConstIterator<GenericKernel<uKernel>>)self)->template Eval<Real, enable_openmp, 6>(v_trg, r_trg, r_src, n_src, v_src);
        if (digits == 5) ((ConstIterator<GenericKernel<uKernel>>)self)->template Eval<Real, enable_openmp, 5>(v_trg, r_trg, r_src, n_src, v_src);
        if (digits == 4) ((ConstIterator<GenericKernel<uKernel>>)self)->template Eval<Real, enable_openmp, 4>(v_trg, r_trg, r_src, n_src, v_src);
      }
    } else {
      if (digits < 12) {
        if (digits == 8) ((ConstIterator<GenericKernel<uKernel>>)self)->template Eval<Real, enable_openmp, 8>(v_trg, r_trg, r_src, n_src, v_src);
        if (digits == 9) ((ConstIterator<GenericKernel<uKernel>>)self)->template Eval<Real, enable_openmp, 9>(v_trg, r_trg, r_src, n_src, v_src);
        if (digits == 10) ((ConstIterator<GenericKernel<uKernel>>)self)->template Eval<Real, enable_openmp,10>(v_trg, r_trg, r_src, n_src, v_src);
        if (digits == 11) ((ConstIterator<GenericKernel<uKernel>>)self)->template Eval<Real, enable_openmp,11>(v_trg, r_trg, r_src, n_src, v_src);
      } else {
        if (digits == 12) ((ConstIterator<GenericKernel<uKernel>>)self)->template Eval<Real, enable_openmp,12>(v_trg, r_trg, r_src, n_src, v_src);
        if (digits == 13) ((ConstIterator<GenericKernel<uKernel>>)self)->template Eval<Real, enable_openmp,13>(v_trg, r_trg, r_src, n_src, v_src);
        if (digits == 14) ((ConstIterator<GenericKernel<uKernel>>)self)->template Eval<Real, enable_openmp,14>(v_trg, r_trg, r_src, n_src, v_src);
        if (digits == 15) ((ConstIterator<GenericKernel<uKernel>>)self)->template Eval<Real, enable_openmp,15>(v_trg, r_trg, r_src, n_src, v_src);
        if (digits >= 16) ((ConstIterator<GenericKernel<uKernel>>)self)->template Eval<Real, enable_openmp,-1>(v_trg, r_trg, r_src, n_src, v_src);
      }
    }
  }

  // Direct summation:
  //   v_trg += uKerScaleFactor * sum_s K(r_trg, r_src[s], n_src[s]) * v_src[s]
  // v_trg is reallocated and zeroed if its dimension is not Nt*KDIM1; otherwise
  // contributions are accumulated into the existing values.
  // digits == -1 selects accuracy matching the significand of Real
  // (0.3010299957 = log10(2) converts significand bits to decimal digits).
  template <class Real, bool enable_openmp=false, Integer digits=-1> void Eval(Vector<Real>& v_trg, const Vector<Real>& r_trg, const Vector<Real>& r_src, const Vector<Real>& n_src, const Vector<Real>& v_src) const {
    static constexpr Integer digits_ = (digits==-1 ? (Integer)(TypeTraits<Real>::SigBits*0.3010299957) : digits);
    static constexpr Integer VecLen = DefaultVecLen<Real>();
    using RealVec = Vec<Real, VecLen>;
    // Accumulate one source's contribution to VecLen targets held in SIMD registers.
    auto uKerEval = [this](RealVec (&vt)[KDIM1], const RealVec (&xt)[DIM], const RealVec (&xs)[DIM], const RealVec (&ns)[N_DIM_], const RealVec (&vs)[KDIM0]) {
      RealVec dX[DIM], U[KDIM0][KDIM1];
      for (Integer i = 0; i < DIM; i++) dX[i] = xt[i] - xs[i];
      uKerMatrix<digits_>(U, dX, ns, ctx_ptr);
      for (Integer k0 = 0; k0 < KDIM0; k0++) {
        for (Integer k1 = 0; k1 < KDIM1; k1++) {
          vt[k1] = FMA(U[k0][k1], vs[k0], vt[k1]);
        }
      }
    };
    const Long Ns = r_src.Dim() / DIM;
    const Long Nt = r_trg.Dim() / DIM;
    SCTL_ASSERT(r_trg.Dim() == Nt*DIM);
    SCTL_ASSERT(r_src.Dim() == Ns*DIM);
    SCTL_ASSERT(v_src.Dim() == Ns*KDIM0);
    SCTL_ASSERT(n_src.Dim() == Ns*N_DIM || !N_DIM);
    if (v_trg.Dim() != Nt*KDIM1) {
      v_trg.ReInit(Nt*KDIM1);
      v_trg.SetZero();
    }
    const Long NNt = ((Nt + VecLen - 1) / VecLen) * VecLen; // Nt rounded up to a multiple of VecLen
    if (NNt == VecLen) { // fast path: all targets fit in one SIMD vector
      RealVec xt[DIM], vt[KDIM1], xs[DIM], ns[N_DIM_], vs[KDIM0];
      for (Integer k = 0; k < KDIM1; k++) vt[k] = RealVec::Zero();
      for (Integer k = 0; k < DIM; k++) { // gather target coordinates into SIMD lanes (zero-padded)
        alignas(sizeof(RealVec)) StaticArray<Real,VecLen> Xt;
        RealVec::Zero().StoreAligned(&Xt[0]);
        for (Integer i = 0; i < Nt; i++) Xt[i] = r_trg[i*DIM+k];
        xt[k] = RealVec::LoadAligned(&Xt[0]);
      }
      for (Long s = 0; s < Ns; s++) { // broadcast each source to all lanes and accumulate
        for (Integer k = 0; k < DIM; k++) xs[k] = RealVec::Load1(&r_src[s*DIM+k]);
        for (Integer k = 0; k < N_DIM; k++) ns[k] = RealVec::Load1(&n_src[s*N_DIM+k]);
        for (Integer k = 0; k < KDIM0; k++) vs[k] = RealVec::Load1(&v_src[s*KDIM0+k]);
        uKerEval(vt, xt, xs, ns, vs);
      }
      for (Integer k = 0; k < KDIM1; k++) { // scatter scaled lane results back to v_trg
        alignas(sizeof(RealVec)) StaticArray<Real,VecLen> out;
        vt[k].StoreAligned(&out[0]);
        for (Long t = 0; t < Nt; t++) {
          v_trg[t*KDIM1+k] += out[t] * uKernel::template uKerScaleFactor<Real>();
        }
      }
    } else { // general path: tile targets over SIMD lanes
      // Non-owning wrappers over the source data (no copies).
      const Matrix<Real> Xs_(Ns, DIM, (Iterator<Real>)r_src.begin(), false);
      const Matrix<Real> Ns_(Ns, N_DIM, (Iterator<Real>)n_src.begin(), false);
      const Matrix<Real> Vs_(Ns, KDIM0, (Iterator<Real>)v_src.begin(), false);
      Matrix<Real> Xt_(DIM, NNt), Vt_(KDIM1, NNt); // transposed, padded target buffers
      for (Long k = 0; k < DIM; k++) { // Set Xt_
        for (Long i = 0; i < Nt; i++) {
          Xt_[k][i] = r_trg[i*DIM+k];
        }
        for (Long i = Nt; i < NNt; i++) { // zero padding for the last partial vector
          Xt_[k][i] = 0;
        }
      }
      if (enable_openmp) { // Compute Vt_
        #pragma omp parallel for schedule(static)
        for (Long t = 0; t < NNt; t += VecLen) {
          RealVec xt[DIM], vt[KDIM1], xs[DIM], ns[N_DIM_], vs[KDIM0];
          for (Integer k = 0; k < KDIM1; k++) vt[k] = RealVec::Zero();
          for (Integer k = 0; k < DIM; k++) xt[k] = RealVec::LoadAligned(&Xt_[k][t]);
          for (Long s = 0; s < Ns; s++) {
            for (Integer k = 0; k < DIM; k++) xs[k] = RealVec::Load1(&Xs_[s][k]);
            for (Integer k = 0; k < N_DIM; k++) ns[k] = RealVec::Load1(&Ns_[s][k]);
            for (Integer k = 0; k < KDIM0; k++) vs[k] = RealVec::Load1(&Vs_[s][k]);
            uKerEval(vt, xt, xs, ns, vs);
          }
          for (Integer k = 0; k < KDIM1; k++) vt[k].StoreAligned(&Vt_[k][t]);
        }
      } else { // identical loop body without the OpenMP pragma
        for (Long t = 0; t < NNt; t += VecLen) {
          RealVec xt[DIM], vt[KDIM1], xs[DIM], ns[N_DIM_], vs[KDIM0];
          for (Integer k = 0; k < KDIM1; k++) vt[k] = RealVec::Zero();
          for (Integer k = 0; k < DIM; k++) xt[k] = RealVec::LoadAligned(&Xt_[k][t]);
          for (Long s = 0; s < Ns; s++) {
            for (Integer k = 0; k < DIM; k++) xs[k] = RealVec::Load1(&Xs_[s][k]);
            for (Integer k = 0; k < N_DIM; k++) ns[k] = RealVec::Load1(&Ns_[s][k]);
            for (Integer k = 0; k < KDIM0; k++) vs[k] = RealVec::Load1(&Vs_[s][k]);
            uKerEval(vt, xt, xs, ns, vs);
          }
          for (Integer k = 0; k < KDIM1; k++) vt[k].StoreAligned(&Vt_[k][t]);
        }
      }
      for (Long k = 0; k < KDIM1; k++) { // v_trg += Vt_
        for (Long i = 0; i < Nt; i++) {
          v_trg[i*KDIM1+k] += Vt_[k][i] * uKernel::template uKerScaleFactor<Real>();
        }
      }
    }
  }

  // Assemble the dense kernel matrix M (Ns*KDIM0 rows, Nt*KDIM1 columns),
  // scaled by uKerScaleFactor. Three cases:
  //  - single target (Xt.Dim() == DIM): vectorize over sources,
  //  - single source (Xs.Dim() == DIM): vectorize over targets,
  //  - general: recurse one source at a time into row-blocks of M
  //    (optionally OpenMP-parallel over sources).
  template <class Real, bool enable_openmp=false, Integer digits=-1> void KernelMatrix(Matrix<Real>& M, const Vector<Real>& Xt, const Vector<Real>& Xs, const Vector<Real>& Xn) const {
    static constexpr Integer digits_ = (digits==-1 ? (Integer)(TypeTraits<Real>::SigBits*0.3010299957) : digits);
    static constexpr Integer VecLen = DefaultVecLen<Real>();
    using VecType = Vec<Real, VecLen>;
    const Long Ns = Xs.Dim()/DIM;
    const Long Nt = Xt.Dim()/DIM;
    if (M.Dim(0) != Ns*KDIM0 || M.Dim(1) != Nt*KDIM1) {
      M.ReInit(Ns*KDIM0, Nt*KDIM1);
      M.SetZero();
    }
    if (Xt.Dim() == DIM) { // single target: vectorize over sources
      alignas(sizeof(VecType)) StaticArray<Real,VecLen> Xs_[DIM];
      alignas(sizeof(VecType)) StaticArray<Real,VecLen> Xn_[N_DIM_];
      alignas(sizeof(VecType)) StaticArray<Real,VecLen> M_[KDIM0*KDIM1];
      for (Integer k = 0; k < DIM; k++) VecType::Zero().StoreAligned(&Xs_[k][0]);
      for (Integer k = 0; k < N_DIM; k++) VecType::Zero().StoreAligned(&Xn_[k][0]);
      VecType vec_Xt[DIM], vec_dX[DIM], vec_Xn[N_DIM_], vec_M[KDIM0][KDIM1];
      for (Integer k = 0; k < DIM; k++) { // Set vec_Xt
        vec_Xt[k] = VecType::Load1(&Xt[k]);
      }
      for (Long i0 = 0; i0 < Ns; i0+=VecLen) { // process sources VecLen at a time
        const Long Ns_ = std::min<Long>(VecLen, Ns-i0); // sources in this tile
        for (Long i1 = 0; i1 < Ns_; i1++) { // Set Xs_
          for (Long k = 0; k < DIM; k++) {
            Xs_[k][i1] = Xs[(i0+i1)*DIM+k];
          }
        }
        for (Long k = 0; k < DIM; k++) { // Set vec_dX
          vec_dX[k] = vec_Xt[k] - VecType::LoadAligned(&Xs_[k][0]);
        }
        if (N_DIM) { // Set vec_Xn
          for (Long i1 = 0; i1 < Ns_; i1++) { // Set Xn_
            for (Long k = 0; k < N_DIM; k++) {
              Xn_[k][i1] = Xn[(i0+i1)*N_DIM+k];
            }
          }
          for (Long k = 0; k < N_DIM; k++) { // Set vec_Xn
            vec_Xn[k] = VecType::LoadAligned(&Xn_[k][0]);
          }
        }
        uKerMatrix<digits_>(vec_M, vec_dX, vec_Xn, ctx_ptr);
        for (Integer k0 = 0; k0 < KDIM0; k0++) { // Set M_
          for (Integer k1 = 0; k1 < KDIM1; k1++) {
            vec_M[k0][k1].StoreAligned(&M_[k0*KDIM1+k1][0]);
          }
        }
        for (Long i1 = 0; i1 < Ns_; i1++) { // Set M (scaled scatter; Nt == 1 so column index is k1)
          for (Integer k0 = 0; k0 < KDIM0; k0++) {
            for (Integer k1 = 0; k1 < KDIM1; k1++) {
              M[(i0+i1)*KDIM0+k0][k1] = M_[k0*KDIM1+k1][i1] * uKernel::template uKerScaleFactor<Real>();
            }
          }
        }
      }
    } else if (Xs.Dim() == DIM) { // single source: vectorize over targets
      alignas(sizeof(VecType)) StaticArray<Real,VecLen> Xt_[DIM];
      alignas(sizeof(VecType)) StaticArray<Real,VecLen> M_[KDIM0*KDIM1];
      for (Integer k = 0; k < DIM; k++) VecType::Zero().StoreAligned(&Xt_[k][0]);
      VecType vec_Xs[DIM], vec_dX[DIM], vec_Xn[N_DIM_], vec_M[KDIM0][KDIM1];
      for (Integer k = 0; k < DIM; k++) { // Set vec_Xs
        vec_Xs[k] = VecType::Load1(&Xs[k]);
      }
      for (Long k = 0; k < N_DIM; k++) { // Set vec_Xn
        vec_Xn[k] = VecType::Load1(&Xn[k]);
      }
      for (Long i0 = 0; i0 < Nt; i0+=VecLen) { // process targets VecLen at a time
        const Long Nt_ = std::min<Long>(VecLen, Nt-i0); // targets in this tile
        for (Long i1 = 0; i1 < Nt_; i1++) { // Set Xt_
          for (Long k = 0; k < DIM; k++) {
            Xt_[k][i1] = Xt[(i0+i1)*DIM+k];
          }
        }
        for (Long k = 0; k < DIM; k++) { // Set vec_dX
          vec_dX[k] = VecType::LoadAligned(&Xt_[k][0]) - vec_Xs[k];
        }
        uKerMatrix<digits_>(vec_M, vec_dX, vec_Xn, ctx_ptr);
        for (Integer k0 = 0; k0 < KDIM0; k0++) { // Set M_
          for (Integer k1 = 0; k1 < KDIM1; k1++) {
            vec_M[k0][k1].StoreAligned(&M_[k0*KDIM1+k1][0]);
          }
        }
        for (Long i1 = 0; i1 < Nt_; i1++) { // Set M (Ns == 1 so row index is k0)
          for (Integer k0 = 0; k0 < KDIM0; k0++) {
            for (Integer k1 = 0; k1 < KDIM1; k1++) {
              M[k0][(i0+i1)*KDIM1+k1] = M_[k0*KDIM1+k1][i1] * uKernel::template uKerScaleFactor<Real>();
            }
          }
        }
      }
    } else { // general case: one recursive call per source, each filling a KDIM0-row block of M
      if (enable_openmp) {
        #pragma omp parallel for schedule(static)
        for (Long i = 0; i < Ns; i++) {
          Matrix<Real> M_(KDIM0, Nt*KDIM1, M.begin() + i*KDIM0*Nt*KDIM1, false); // view into M's row block
          const Vector<Real> Xs_(DIM, (Iterator<Real>)Xs.begin() + i*DIM, false);
          const Vector<Real> Xn_(N_DIM, (Iterator<Real>)Xn.begin() + i*N_DIM, false);
          KernelMatrix<Real,enable_openmp,digits>(M_, Xt, Xs_, Xn_);
        }
      } else {
        for (Long i = 0; i < Ns; i++) {
          Matrix<Real> M_(KDIM0, Nt*KDIM1, M.begin() + i*KDIM0*Nt*KDIM1, false); // view into M's row block
          const Vector<Real> Xs_(DIM, (Iterator<Real>)Xs.begin() + i*DIM, false);
          const Vector<Real> Xn_(N_DIM, (Iterator<Real>)Xn.begin() + i*N_DIM, false);
          KernelMatrix<Real,enable_openmp,digits>(M_, Xt, Xs_, Xn_);
        }
      }
    }
  }

  // Evaluate the micro-kernel matrix; dispatches through uKerHelper so that
  // kernels without a normal argument ignore `n`.
  template <Integer digits, class VecType, class NormalType> static void uKerMatrix(VecType (&u)[KDIM0][KDIM1], const VecType (&r)[DIM], const NormalType& n, const void* ctx_ptr) {
    uKerHelper<uKernel,KDIM0,KDIM1,DIM,N_DIM>::template MatEval<digits>(u, r, n, ctx_ptr);
  };

  private:

  void* ctx_ptr; // opaque user context (not owned; passed through to uKerMatrix)
};
  290. namespace kernel_impl {
  291. struct Laplace3D_FxU {
  292. static const std::string& Name() {
  293. static const std::string name = "Laplace3D-FxU";
  294. return name;
  295. }
  296. static constexpr Integer FLOPS() {
  297. return 6;
  298. }
  299. template <class Real> static constexpr Real uKerScaleFactor() {
  300. return 1 / (4 * const_pi<Real>());
  301. }
  302. template <Integer digits, class VecType> static void uKerMatrix(VecType (&u)[1][1], const VecType (&r)[3], const void* ctx_ptr) {
  303. VecType r2 = r[0]*r[0]+r[1]*r[1]+r[2]*r[2];
  304. VecType rinv = approx_rsqrt<digits>(r2, r2 > VecType::Zero());
  305. u[0][0] = rinv;
  306. }
  307. };
  308. struct Laplace3D_DxU {
  309. static const std::string& Name() {
  310. static const std::string name = "Laplace3D-DxU";
  311. return name;
  312. }
  313. static constexpr Integer FLOPS() {
  314. return 14;
  315. }
  316. template <class Real> static constexpr Real uKerScaleFactor() {
  317. return 1 / (4 * const_pi<Real>());
  318. }
  319. template <Integer digits, class VecType> static void uKerMatrix(VecType (&u)[1][1], const VecType (&r)[3], const VecType (&n)[3], const void* ctx_ptr) {
  320. VecType r2 = r[0]*r[0]+r[1]*r[1]+r[2]*r[2];
  321. VecType rinv = approx_rsqrt<digits>(r2, r2 > VecType::Zero());
  322. VecType rdotn = r[0]*n[0] + r[1]*n[1] + r[2]*n[2];
  323. VecType rinv3 = rinv * rinv * rinv;
  324. u[0][0] = rdotn * rinv3;
  325. }
  326. };
  327. struct Laplace3D_FxdU {
  328. static const std::string& Name() {
  329. static const std::string name = "Laplace3D-FxdU";
  330. return name;
  331. }
  332. static constexpr Integer FLOPS() {
  333. return 11;
  334. }
  335. template <class Real> static constexpr Real uKerScaleFactor() {
  336. return -1 / (4 * const_pi<Real>());
  337. }
  338. template <Integer digits, class VecType> static void uKerMatrix(VecType (&u)[1][3], const VecType (&r)[3], const void* ctx_ptr) {
  339. VecType r2 = r[0]*r[0]+r[1]*r[1]+r[2]*r[2];
  340. VecType rinv = approx_rsqrt<digits>(r2, r2 > VecType::Zero());
  341. VecType rinv3 = rinv * rinv * rinv;
  342. u[0][0] = r[0] * rinv3;
  343. u[0][1] = r[1] * rinv3;
  344. u[0][2] = r[2] * rinv3;
  345. }
  346. };
  347. struct Stokes3D_FxU {
  348. static const std::string& Name() {
  349. static const std::string name = "Stokes3D-FxU";
  350. return name;
  351. }
  352. static constexpr Integer FLOPS() {
  353. return 23;
  354. }
  355. template <class Real> static constexpr Real uKerScaleFactor() {
  356. return 1 / (8 * const_pi<Real>());
  357. }
  358. template <Integer digits, class VecType> static void uKerMatrix(VecType (&u)[3][3], const VecType (&r)[3], const void* ctx_ptr) {
  359. VecType r2 = r[0]*r[0]+r[1]*r[1]+r[2]*r[2];
  360. VecType rinv = approx_rsqrt<digits>(r2, r2 > VecType::Zero());
  361. VecType rinv3 = rinv*rinv*rinv;
  362. for (Integer i = 0; i < 3; i++) {
  363. for (Integer j = 0; j < 3; j++) {
  364. u[i][j] = (i==j ? rinv : VecType::Zero()) + r[i]*r[j]*rinv3;
  365. }
  366. }
  367. }
  368. };
  369. struct Stokes3D_DxU {
  370. static const std::string& Name() {
  371. static const std::string name = "Stokes3D-DxU";
  372. return name;
  373. }
  374. static constexpr Integer FLOPS() {
  375. return 26;
  376. }
  377. template <class Real> static constexpr Real uKerScaleFactor() {
  378. return 3 / (4 * const_pi<Real>());
  379. }
  380. template <Integer digits, class VecType> static void uKerMatrix(VecType (&u)[3][3], const VecType (&r)[3], const VecType (&n)[3], const void* ctx_ptr) {
  381. VecType r2 = r[0]*r[0]+r[1]*r[1]+r[2]*r[2];
  382. VecType rinv = approx_rsqrt<digits>(r2, r2 > VecType::Zero());
  383. VecType rinv2 = rinv*rinv;
  384. VecType rinv5 = rinv2*rinv2*rinv;
  385. VecType rdotn_rinv5 = (r[0]*n[0] + r[1]*n[1] + r[2]*n[2])*rinv5;
  386. for (Integer i = 0; i < 3; i++) {
  387. for (Integer j = 0; j < 3; j++) {
  388. u[i][j] = r[i]*r[j]*rdotn_rinv5;
  389. }
  390. }
  391. }
  392. };
  393. struct Stokes3D_FxT {
  394. static const std::string& Name() {
  395. static const std::string name = "Stokes3D-FxT";
  396. return name;
  397. }
  398. static constexpr Integer FLOPS() {
  399. return 39;
  400. }
  401. template <class Real> static constexpr Real uKerScaleFactor() {
  402. return -3 / (4 * const_pi<Real>());
  403. }
  404. template <Integer digits, class VecType> static void uKerMatrix(VecType (&u)[3][9], const VecType (&r)[3], const void* ctx_ptr) {
  405. VecType r2 = r[0]*r[0]+r[1]*r[1]+r[2]*r[2];
  406. VecType rinv = approx_rsqrt<digits>(r2, r2 > VecType::Zero());
  407. VecType rinv2 = rinv*rinv;
  408. VecType rinv5 = rinv2*rinv2*rinv;
  409. for (Integer i = 0; i < 3; i++) {
  410. for (Integer j = 0; j < 3; j++) {
  411. for (Integer k = 0; k < 3; k++) {
  412. u[i][j*3+k] = r[i]*r[j]*r[k]*rinv5;
  413. }
  414. }
  415. }
  416. }
  417. };
  418. struct Stokes3D_FSxU {
  419. static const std::string& Name() {
  420. static const std::string name = "Stokes3D-FSxU";
  421. return name;
  422. }
  423. static constexpr Integer FLOPS() {
  424. return 26;
  425. }
  426. template <class Real> static constexpr Real uKerScaleFactor() {
  427. return 1 / (8 * const_pi<Real>());
  428. }
  429. template <Integer digits, class VecType> static void uKerMatrix(VecType (&u)[4][3], const VecType (&r)[3], const void* ctx_ptr) {
  430. VecType r2 = r[0]*r[0]+r[1]*r[1]+r[2]*r[2];
  431. VecType rinv = approx_rsqrt<digits>(r2, r2 > VecType::Zero());
  432. VecType rinv3 = rinv*rinv*rinv;
  433. for (Integer i = 0; i < 3; i++) {
  434. for (Integer j = 0; j < 3; j++) {
  435. u[i][j] = (i==j ? rinv : VecType::Zero()) + r[i]*r[j]*rinv3;
  436. }
  437. }
  438. for (Integer j = 0; j < 3; j++) {
  439. u[3][j] = r[j]*rinv3;
  440. }
  441. }
  442. };
  443. struct Stokes3D_FxUP {
  444. static const std::string& Name() {
  445. static const std::string name = "Stokes3D-FxUP";
  446. return name;
  447. }
  448. static constexpr Integer FLOPS() {
  449. return 26;
  450. }
  451. template <class Real> static constexpr Real uKerScaleFactor() {
  452. return 1 / (8 * const_pi<Real>());
  453. }
  454. template <Integer digits, class VecType> static void uKerMatrix(VecType (&u)[3][4], const VecType (&r)[3], const void* ctx_ptr) {
  455. VecType r2 = r[0]*r[0]+r[1]*r[1]+r[2]*r[2];
  456. VecType rinv = approx_rsqrt<digits>(r2, r2 > VecType::Zero());
  457. VecType rinv3 = rinv*rinv*rinv;
  458. for (Integer i = 0; i < 3; i++) {
  459. for (Integer j = 0; j < 3; j++) {
  460. u[i][j] = (i==j ? rinv : VecType::Zero()) + r[i]*r[j]*rinv3;
  461. }
  462. }
  463. for (Integer i = 0; i < 3; i++) {
  464. u[i][3] = r[i]*rinv3;
  465. }
  466. }
  467. };
  468. } // namespace kernel_impl
// Public kernel types: each wraps a kernel_impl micro-kernel in GenericKernel,
// which supplies the vectorized Eval() and KernelMatrix() drivers.
struct Laplace3D_FxU : public GenericKernel<kernel_impl::Laplace3D_FxU> {};   // single-layer: charge -> potential
struct Laplace3D_DxU : public GenericKernel<kernel_impl::Laplace3D_DxU> {};   // double-layer: dipole -> potential
struct Laplace3D_FxdU : public GenericKernel<kernel_impl::Laplace3D_FxdU>{};  // single-layer gradient: charge -> field
struct Stokes3D_FxU : public GenericKernel<kernel_impl::Stokes3D_FxU> {};     // Stokeslet: force -> velocity
struct Stokes3D_DxU : public GenericKernel<kernel_impl::Stokes3D_DxU> {};     // stresslet: dipole -> velocity
struct Stokes3D_FxT : public GenericKernel<kernel_impl::Stokes3D_FxT> {};     // force -> traction tensor
struct Stokes3D_FSxU : public GenericKernel<kernel_impl::Stokes3D_FSxU> {}; // for FMM translations - M2M, M2L, M2T
struct Stokes3D_FxUP : public GenericKernel<kernel_impl::Stokes3D_FxUP> {};   // force -> velocity + pressure
  477. } // end namespace
  478. #endif //_SCTL_KERNEL_FUNCTIONS_HPP_