lagrange-interp.txx 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122
  1. #include SCTL_INCLUDE(vector.hpp)
  2. namespace SCTL_NAMESPACE {
  3. template <class Real> void LagrangeInterp<Real>::Interpolate(Vector<Real>& wts, const Vector<Real>& src_nds, const Vector<Real>& trg_nds) {
  4. static constexpr Integer VecLen = DefaultVecLen<Real>();
  5. using VecType = Vec<Real, VecLen>;
  6. VecType vec_one((Real)1);
  7. const Long Nsrc = src_nds.Dim();
  8. const Long Ntrg = trg_nds.Dim();
  9. const Long Ntrg_ = (Ntrg/VecLen)*VecLen;
  10. if (wts.Dim() != Nsrc*Ntrg) wts.ReInit(Nsrc*Ntrg);
  11. Matrix<Real> M(Nsrc, Ntrg, wts.begin(), false);
  12. StaticArray<Real,50> w_buff;
  13. Vector<Real> w(Nsrc, (Nsrc>=50?NullIterator<Real>():w_buff), (Nsrc>=50));
  14. for (Integer j = 0; j < Nsrc; j++) {
  15. Real w_inv = 1;
  16. Real src_nds_j(src_nds[j]);
  17. for (Integer k = 0; k < j; k++) w_inv *= src_nds[k] - src_nds_j;
  18. for (Integer k = j+1; k < Nsrc; k++) w_inv *= src_nds[k] - src_nds_j;
  19. w[j] = 1/w_inv;
  20. }
  21. if (0) {
  22. for (Long i1 = 0; i1 < Ntrg_; i1+=VecLen) {
  23. VecType x = VecType::Load(&trg_nds[i1]);
  24. for (Integer j = 0; j < Nsrc; j++) {
  25. VecType y0(vec_one);
  26. for (Integer k = 0; k < j; k++) y0 *= VecType(src_nds[k]) - x;
  27. for (Integer k = j+1; k < Nsrc; k++) y0 *= VecType(src_nds[k]) - x;
  28. VecType y = y0 * w[j];
  29. y.Store(&M[j][i1]);
  30. }
  31. }
  32. for (Long i1 = Ntrg_; i1 < Ntrg; i1++) {
  33. Real x = trg_nds[i1];
  34. for (Integer j = 0; j < Nsrc; j++) {
  35. Real y0 = 1;
  36. for (Integer k = 0; k < j; k++) y0 *= src_nds[k] - x;
  37. for (Integer k = j+1; k < Nsrc; k++) y0 *= src_nds[k] - x;
  38. M[j][i1] = y0 * w[j];
  39. }
  40. }
  41. }
  42. if (1) { // Barycentric // TODO: vectorize
  43. //static constexpr Integer digits = (Integer)(TypeTraits<Real>::SigBits*0.3010299957);
  44. for (Long t = 0; t< Ntrg; t++) {
  45. Long s_ = -1;
  46. Real scal = 0;
  47. for (Long s = 0; s < Nsrc; s++) {
  48. if (trg_nds[t] == src_nds[s]) s_ = s;
  49. M[s][t] = w[s] / (trg_nds[t] - src_nds[s]);
  50. scal += M[s][t];
  51. }
  52. if (s_ == -1) {
  53. scal = 1/scal;
  54. for (Long s = 0; s < Nsrc; s++) M[s][t] *= scal;
  55. } else {
  56. for (Long s = 0; s < Nsrc; s++) M[s][t] = 0;
  57. M[s_][t] = 1;
  58. }
  59. }
  60. }
  61. }
  62. template <class Real> void LagrangeInterp<Real>::Derivative(Vector<Real>& df, const Vector<Real>& f, const Vector<Real>& nds) {
  63. Long N = nds.Dim();
  64. Long dof = f.Dim() / N;
  65. SCTL_ASSERT(f.Dim() == N * dof);
  66. if (df.Dim() != N * dof) df.ReInit(N * dof);
  67. if (N*dof == 0) return;
  68. auto dp = [&nds,&N](Real x, Long i) {
  69. Real scal = 1;
  70. for (Long j = 0; j < N; j++) {
  71. if (i!=j) scal *= (nds[i] - nds[j]);
  72. }
  73. scal = 1/scal;
  74. Real wt = 0;
  75. for (Long k = 0; k < N; k++) {
  76. Real wt_ = 1;
  77. if (k!=i) {
  78. for (Long j = 0; j < N; j++) {
  79. if (j!=k && j!=i) wt_ *= (x - nds[j]);
  80. }
  81. wt += wt_;
  82. }
  83. }
  84. return wt * scal;
  85. };
  86. for (Long k = 0; k < dof; k++) {
  87. for (Long i = 0; i < N; i++) {
  88. Real df_ = 0;
  89. for (Long j = 0; j < N; j++) {
  90. df_ += f[k*N+j] * dp(nds[i],j);
  91. }
  92. df[k*N+i] = df_;
  93. }
  94. }
  95. }
  96. template <class Real> void LagrangeInterp<Real>::test() { // TODO: cleanup
  97. Matrix<Real> f(1,3);
  98. f[0][0] = 0; f[0][1] = 1; f[0][2] = 0.5;
  99. Vector<Real> src, trg;
  100. for (Long i = 0; i < 3; i++) src.PushBack(i);
  101. for (Long i = 0; i < 11; i++) trg.PushBack(i*0.2);
  102. Vector<Real> wts;
  103. Interpolate(wts,src,trg);
  104. Matrix<Real> Mwts(src.Dim(), trg.Dim(), wts.begin(), false);
  105. Matrix<Real> ff = f * Mwts;
  106. std::cout<<ff<<'\n';
  107. Vector<Real> df;
  108. Derivative(df, Vector<Real>(f.Dim(0)*f.Dim(1),f.begin()), src);
  109. std::cout<<df<<'\n';
  110. }
  111. }