ode-solver.txx 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163
  1. #include SCTL_INCLUDE(lagrange-interp.hpp)
  2. namespace SCTL_NAMESPACE {
  3. template <class Real> SDC<Real>::SDC(const Integer Order_, const Comm& comm_) : Order(Order_), comm(comm_) {
  4. #ifdef SCTL_QUAD_T
  5. using ValueType = QuadReal;
  6. #else
  7. using ValueType = long double;
  8. #endif
  9. auto second_kind_cheb_nds = [](const Integer Order) {
  10. Vector<ValueType> x_cheb(Order);
  11. for (Long i = 0; i < Order; i++) {
  12. x_cheb[i] = 0.5 - 0.5 * cos(const_pi<ValueType>() * i / (Order - 1));
  13. }
  14. return x_cheb;
  15. };
  16. const auto nds0 = second_kind_cheb_nds(Order);
  17. SCTL_ASSERT(nds0.Dim() == Order);
  18. { // Set M_error
  19. Integer TRUNC_Order = Order;
  20. if (Order >= 2) TRUNC_Order = Order - 1;
  21. if (Order >= 6) TRUNC_Order = Order - 1;
  22. if (Order >= 9) TRUNC_Order = Order - 1;
  23. const auto nds1 = second_kind_cheb_nds(TRUNC_Order);
  24. SCTL_ASSERT(nds1.Dim() == TRUNC_Order);
  25. Matrix<ValueType> Minterp0(Order, TRUNC_Order);
  26. Matrix<ValueType> Minterp1(TRUNC_Order, Order);
  27. Vector<ValueType> interp0(Order*TRUNC_Order, Minterp0.begin(), false);
  28. Vector<ValueType> interp1(TRUNC_Order*Order, Minterp1.begin(), false);
  29. LagrangeInterp<ValueType>::Interpolate(interp0, nds0, nds1);
  30. LagrangeInterp<ValueType>::Interpolate(interp1, nds1, nds0);
  31. Matrix<ValueType> M_error_ = (Minterp0 * Minterp1).Transpose();
  32. for (Long i = 0; i < Order; i++) M_error_[i][i] -= 1;
  33. M_error.ReInit(Order, Order);
  34. for (Long i = 0; i < Order*Order; i++) M_error[0][i] = (Real)M_error_[0][i];
  35. }
  36. { // Set M_time_step
  37. const auto qx = ChebQuadRule<ValueType>::ComputeNds(Order);
  38. const auto qw = ChebQuadRule<ValueType>::ComputeWts(Order);
  39. const Matrix<ValueType> Mw(Order, 1, (Iterator<ValueType>)qw.begin(), false);
  40. SCTL_ASSERT(qw.Dim() == Order);
  41. SCTL_ASSERT(qx.Dim() == Order);
  42. Matrix<ValueType> Minterp(Order, Order), M_time_step_(Order, Order);
  43. Vector<ValueType> interp(Order*Order, Minterp.begin(), false);
  44. for (Integer i = 0; i < Order; i++) {
  45. LagrangeInterp<ValueType>::Interpolate(interp, nds0, qx*nds0[i]);
  46. Matrix<ValueType> M_time_step_i(Order,1, M_time_step_[i], false);
  47. M_time_step_i = Minterp * Mw * nds0[i];
  48. }
  49. M_time_step.ReInit(Order, Order);
  50. for (Long i = 0; i < Order*Order; i++) M_time_step[0][i] = (Real)M_time_step_[0][i];
  51. }
  52. }
  53. // solve u = u0 + \int_0^{dt} F(u)
  54. template <class Real> void SDC<Real>::operator()(Vector<Real>* u, const Real dt, const Vector<Real>& u0, const Fn& F, Integer N_picard, const Real tol_picard, Real* error_interp, Real* error_picard, Real* norm_dudt) const {
  55. auto max_norm = [] (const Matrix<Real>& M, const Comm& comm) {
  56. StaticArray<Real,2> max_val{0,0};
  57. for (Long i = 0; i < M.Dim(0); i++) {
  58. for (Long j = 0; j < M.Dim(1); j++) {
  59. max_val[0] = std::max<Real>(max_val[0], fabs(M[i][j]));
  60. }
  61. }
  62. comm.Allreduce((ConstIterator<Real>)max_val, (Iterator<Real>)max_val+1, 1, Comm::CommOp::MAX);
  63. return max_val[1];
  64. };
  65. if (N_picard < 0) N_picard = Order;
  66. const Long DOF = u0.Dim();
  67. Matrix<Real> Mu0(Order, DOF);
  68. Matrix<Real> Mu1(Order, DOF);
  69. for (Long j = 0; j < Order; j++) { // Set Mu0
  70. for (Long k = 0; k < DOF; k++) {
  71. Mu0[j][k] = u0[k];
  72. }
  73. }
  74. Matrix<Real> M_dudt(Order, DOF);
  75. { // Set M_dudt
  76. Vector<Real> dudt_(DOF, M_dudt[0], false);
  77. F(&dudt_, Vector<Real>(DOF, Mu0[0], false));
  78. for (Long i = 1; i < Order; i++) {
  79. for (Long j = 0; j < DOF; j++) {
  80. M_dudt[i][j] = M_dudt[0][j];
  81. }
  82. }
  83. }
  84. Matrix<Real> Mv = (M_time_step * M_dudt) * dt;
  85. Mu1 = Mu0 + Mv;
  86. Real picard_err_curr = 0;
  87. for (Long k = 0; k < N_picard; k++) { // Picard iteration
  88. auto Mv_previous = Mv;
  89. for (Long i = 1; i < Order; i++) { // Set M_dudt
  90. Vector<Real> dudt_(DOF, M_dudt[i], false);
  91. F(&dudt_, Vector<Real>(DOF, Mu1[i], false));
  92. }
  93. Mv = (M_time_step * M_dudt) * dt;
  94. Mu1 = Mu0 + Mv;
  95. picard_err_curr = max_norm(Mv - Mv_previous, comm);
  96. if (picard_err_curr < tol_picard) break;
  97. }
  98. if (u->Dim() != DOF) u->ReInit(DOF);
  99. for (Long k = 0; k < DOF; k++) { // Set u
  100. (*u)[k] = Mu1[Order - 1][k];
  101. }
  102. if (error_picard != nullptr) {
  103. (*error_picard) = picard_err_curr;
  104. }
  105. if (error_interp != nullptr) {
  106. (*error_interp) = max_norm(M_error * Mv, comm);
  107. }
  108. if (norm_dudt != nullptr) {
  109. (*norm_dudt) = max_norm(Mv, comm);
  110. }
  111. }
  112. template <class Real> Real SDC<Real>::AdaptiveSolve(Vector<Real>* u, Real dt, const Real T, const Vector<Real>& u0, const Fn& F, const Real tol, const MonitorFn* monitor_callback, bool continue_with_errors, Real* error) const {
  113. const Real eps = machine_eps<Real>();
  114. Vector<Real> u_, u0_ = u0;
  115. Real t = 0;
  116. Real error_ = 0;
  117. while (t < T && dt > eps*T) {
  118. Real error_interp, error_picard, norm_dudt;
  119. (*this)(&u_, dt, u0_, F, 2*Order, tol*dt*pow<Real>(0.9,Order), &error_interp, &error_picard, &norm_dudt);
  120. Real tol_ = std::max<Real>(tol/T, (tol-error_)/(T-t));
  121. Real max_err = std::max<Real>(error_interp, error_picard);
  122. //std::cout<<t<<' '<<dt<<' '<<error_interp/dt<<' '<<error_picard/dt<<' '<<max_err/norm_dudt/eps<<'\n';
  123. if (max_err < tol_*dt || (continue_with_errors && max_err/norm_dudt < 2*eps)) { // Accept solution
  124. u0_.Swap(u_);
  125. t = t + dt;
  126. error_ += max_err;
  127. if (monitor_callback) (*monitor_callback)(t, dt, u0_);
  128. }
  129. if (continue_with_errors && max_err/norm_dudt < 2*eps) {
  130. dt = std::min<Real>(T-t, 1.1*dt);
  131. } else {
  132. // Adjust time-step size (Quaife, Biros - JCP 2016)
  133. dt = std::min<Real>(T-t, std::max<Real>(0.5*dt, 0.9*dt*pow<Real>((tol_*dt)/max_err, 1/(Real)(Order))));
  134. }
  135. }
  136. if (t < T || error_ > tol) SCTL_WARN("Could not solve ODE to the requested tolerance.");
  137. if (error != nullptr) (*error) = error_;
  138. (*u) = u0_;
  139. return t;
  140. }
  141. }