ode-solver.hpp 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120
  1. #ifndef _SCTL_ODE_SOLVER_
  2. #define _SCTL_ODE_SOLVER_
  3. #include <sctl/common.hpp>
  4. #include SCTL_INCLUDE(math_utils.hpp)
  5. #include <functional>
  6. namespace SCTL_NAMESPACE {
  7. template <class ValueType> class Vector;
  8. template <class ValueType> class Matrix;
  9. template <class Real> class SDC {
  10. public:
  11. using Fn = std::function<void(Vector<Real>* dudt, const Vector<Real>& u)>;
  12. using MonitorFn = std::function<void(Real t, Real dt, const Vector<Real>& u)>;
  13. /**
  14. * Constructor
  15. *
  16. * @param[in] Order the order of the method.
  17. */
  18. explicit SDC(const Integer Order, const Comm& comm = Comm::Self());
  19. /**
  20. * Apply one step of spectral deferred correction (SDC).
  21. * Compute: u = u0 + \int_0^{dt} F(u)
  22. *
  23. * @param[out] u the solution
  24. * @param[in] dt the step size
  25. * @param[in] u0 the initial value
  26. * @param[in] F the function du/dt
  27. * @param[in] N_picard the maximum number of picard iterations
  28. * @param[in] tol_picard the tolerance for stopping picard iterations
  29. * @param[out] error_interp an estimate of the truncation error of the solution interpolant
  30. * @param[out] error_picard the picard iteration error
  31. * @param[out] norm_dudt maximum norm of du/dt
  32. */
  33. void operator()(Vector<Real>* u, const Real dt, const Vector<Real>& u0, const Fn& F, Integer N_picard = -1, const Real tol_picard = 0, Real* error_interp = nullptr, Real* error_picard = nullptr, Real* norm_dudt = nullptr) const;
  34. /**
  35. * Solve ODE adaptively to required tolerance.
  36. * Compute: u = u0 + \int_0^{T} F(u)
  37. *
  38. * @param[out] u the final solution
  39. * @param[in] dt the initial step size guess
  40. * @param[in] T the final time
  41. * @param[in] u0 the initial value
  42. * @param[in] F the function du/dt
  43. * @param[in] tol the required solution tolerance
  44. * @param[in] monitor_callback a callback function called after each accepted time-step
  45. * @param[in] continue_with_errors tries to compute the best solution even if the required tolerance cannot be satisfied.
  46. * @param[out] error estimate of the final output error
  47. *
  48. * @return the final time (should equal T if no errors)
  49. */
  50. Real AdaptiveSolve(Vector<Real>* u, Real dt, const Real T, const Vector<Real>& u0, const Fn& F, Real tol, const MonitorFn* monitor_callback = nullptr, bool continue_with_errors = false, Real* error = nullptr) const;
  51. static void test_one_step(const Integer Order = 5) {
  52. auto ref_sol = [](Real t) { return cos<Real>(-t); };
  53. auto fn = [](Vector<Real>* dudt, const Vector<Real>& u) {
  54. (*dudt)[0] = -u[1];
  55. (*dudt)[1] = u[0];
  56. };
  57. std::function<void(Vector<Real>*, const Vector<Real>&)> F(fn);
  58. const SDC<Real> ode_solver(Order);
  59. Real t = 0.0, dt = 1.0e-1;
  60. Vector<Real> u, u0(2);
  61. u0[0] = 1.0;
  62. u0[1] = 0.0;
  63. while (t < 10.0) {
  64. Real error_interp, error_picard;
  65. ode_solver(&u, dt, u0, F, -1, 0.0, &error_interp, &error_picard);
  66. { // Accept solution
  67. u0 = u;
  68. t = t + dt;
  69. }
  70. printf("t = %e; ", t);
  71. printf("u = %e; ", u0[0]);
  72. printf("error = %e; ", ref_sol(t) - u0[0]);
  73. printf("time_step_error_estimate = %e; \n", std::max(error_interp, error_picard));
  74. }
  75. }
  76. static void test_adaptive_solve(const Integer Order = 5, const Real tol = 1e-5) {
  77. auto ref_sol = [](Real t) { return cos(-t); };
  78. auto fn = [](Vector<Real>* dudt, const Vector<Real>& u) {
  79. (*dudt)[0] = -u[1];
  80. (*dudt)[1] = u[0];
  81. };
  82. std::function<void(Vector<Real>*, const Vector<Real>&)> F(fn);
  83. Vector<Real> u, u0(2);
  84. u0[0] = 1.0; u0[1] = 0.0;
  85. Real T = 10.0, dt = 1.0e-1;
  86. SDC<Real> ode_solver(Order);
  87. Real t = ode_solver.AdaptiveSolve(&u, dt, T, u0, F, tol);
  88. if (t == T) {
  89. printf("u = %e; ", u[0]);
  90. printf("error = %e; \n", ref_sol(T) - u[0]);
  91. }
  92. }
  93. private:
  94. Matrix<Real> M_time_step, M_error;
  95. Integer Order;
  96. Comm comm;
  97. };
  98. }
  99. #include SCTL_INCLUDE(ode-solver.txx)
  100. #endif //_SCTL_ODE_SOLVER_