ode-solver.hpp 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142
  1. #ifndef _SCTL_ODE_SOLVER_
  2. #define _SCTL_ODE_SOLVER_
  3. #include SCTL_INCLUDE(common.hpp)
  4. #include SCTL_INCLUDE(vector.hpp)
  5. #include SCTL_INCLUDE(matrix.hpp)
  6. #include <functional>
  7. namespace SCTL_NAMESPACE {
  8. template <class Real, Integer ORDER> class SDC {
  9. public:
  10. using Fn = std::function<void(Vector<Real>*, const Vector<Real>&)>;
  11. SDC() {
  12. Vector<Real> x_cheb(ORDER);
  13. for (Long i = 0; i < ORDER; i++) {
  14. x_cheb[i] = 0.5 - 0.5 * cos(const_pi<Real>() * i / (ORDER - 1));
  15. }
  16. Matrix<Real> Mp(ORDER, ORDER);
  17. Matrix<Real> Mi(ORDER, ORDER);
  18. for (Long i = 0; i < ORDER; i++) {
  19. for (Long j = 0; j < ORDER; j++) {
  20. Mp[j][i] = pow<Real>(x_cheb[i],j);
  21. Mi[j][i] = pow<Real>(x_cheb[i],j+1) / (j+1);
  22. }
  23. }
  24. M_time_step = (Mp.pinv() * Mi).Transpose(); // TODO: replace Mp.pinv()
  25. Mp.ReInit(ORDER,ORDER); Mp = 0;
  26. Mi.ReInit(ORDER,ORDER); Mi = 0;
  27. Integer TRUNC_ORDER = ORDER;
  28. if (ORDER >= 2) TRUNC_ORDER = ORDER - 1;
  29. if (ORDER >= 6) TRUNC_ORDER = ORDER - 1;
  30. if (ORDER >= 9) TRUNC_ORDER = ORDER - 1;
  31. for (Long j = 0; j < TRUNC_ORDER; j++) {
  32. for (Long i = 0; i < ORDER; i++) {
  33. Mp[j][i] = pow<Real>(x_cheb[i],j);
  34. Mi[j][i] = pow<Real>(x_cheb[i],j);
  35. }
  36. }
  37. M_error = (Mp.pinv() * Mi).Transpose(); // TODO: replace Mp.pinv()
  38. for (Long i = 0; i < ORDER; i++) M_error[i][i] -= 1;
  39. }
  40. // solve u = \int_0^{dt} F(u)
  41. void operator()(Vector<Real>* u, const Real dt, const Vector<Real>& u0_, const Fn& F, Integer N_picard = ORDER, Real tol_picard = 0, Real* error_interp = nullptr, Real* error_picard = nullptr) {
  42. auto max_norm = [] (const Matrix<Real>& M) {
  43. Real max_val = 0;
  44. for (Long i = 0; i < M.Dim(0); i++) {
  45. for (Long j = 0; j < M.Dim(1); j++) {
  46. max_val = std::max<Real>(max_val, fabs(M[i][j]));
  47. }
  48. }
  49. return max_val;
  50. };
  51. const Long DOF = u0_.Dim();
  52. Matrix<Real> Mu0(ORDER, DOF);
  53. Matrix<Real> Mu1(ORDER, DOF);
  54. for (Long j = 0; j < ORDER; j++) { // Set u0
  55. for (Long k = 0; k < DOF; k++) {
  56. Mu0[j][k] = u0_[k];
  57. }
  58. }
  59. Matrix<Real> M_dudt(ORDER, DOF);
  60. { // Set M_dudt
  61. Vector<Real> dudt_(DOF, M_dudt[0], false);
  62. F(&dudt_, Vector<Real>(DOF, Mu0[0], false));
  63. for (Long i = 1; i < ORDER; i++) {
  64. for (Long j = 0; j < DOF; j++) {
  65. M_dudt[i][j] = M_dudt[0][j];
  66. }
  67. }
  68. }
  69. Mu1 = Mu0 + (M_time_step * M_dudt) * dt;
  70. Matrix<Real> Merr(ORDER, DOF);
  71. for (Long k = 0; k < N_picard; k++) { // Picard iteration
  72. auto Mu_previous = Mu1;
  73. for (Long i = 1; i < ORDER; i++) { // Set M_dudt
  74. Vector<Real> dudt_(DOF, M_dudt[i], false);
  75. F(&dudt_, Vector<Real>(DOF, Mu1[i], false));
  76. }
  77. Mu1 = Mu0 + (M_time_step * M_dudt) * dt;
  78. Merr = Mu1 - Mu_previous;
  79. if (max_norm(Merr) < tol_picard) break;
  80. }
  81. if (u->Dim() != DOF) u->ReInit(DOF);
  82. for (Long k = 0; k < DOF; k++) { // Set u
  83. u[0][k] = Mu1[ORDER - 1][k];
  84. }
  85. if (error_picard != nullptr) {
  86. error_picard[0] = max_norm(Merr);
  87. }
  88. if (error_interp != nullptr) {
  89. Merr = M_error * Mu1;
  90. error_interp[0] = max_norm(Merr);
  91. }
  92. }
  93. static void test() {
  94. auto ref_sol = [](Real t) { return cos(-t); };
  95. auto fn = [](sctl::Vector<Real>* dudt, const sctl::Vector<Real>& u) {
  96. (*dudt)[0] = -u[1];
  97. (*dudt)[1] = u[0];
  98. };
  99. std::function<void(sctl::Vector<Real>*, const sctl::Vector<Real>&)> F(fn);
  100. sctl::SDC<Real, ORDER> ode_solver;
  101. Real t = 0.0, dt = 1.0e-1;
  102. sctl::Vector<Real> u, u0(2);
  103. u0[0] = 1.0;
  104. u0[1] = 0.0;
  105. while (t < 10.0) {
  106. Real error_interp, error_picard;
  107. ode_solver(&u, dt, u0, F, ORDER, 0.0, &error_interp, &error_picard);
  108. { // Accept solution
  109. u0 = u;
  110. t = t + dt;
  111. }
  112. printf("t = %e; ", t);
  113. printf("u1 = %e; ", u0[0]);
  114. printf("u_ref = %e; ", ref_sol(t));
  115. printf("error = %e; ", ref_sol(t) - u0[0]);
  116. printf("time_step_error_estimate = %e; \n", std::max(error_interp, error_picard));
  117. }
  118. }
  119. private:
  120. Matrix<Real> M_time_step, M_error;
  121. };
  122. }
  123. #endif //_SCTL_ODE_SOLVER_