parallel_solver.hpp 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262
  1. #ifndef _SCTL_PARALLEL_SOLVER_HPP_
  2. #define _SCTL_PARALLEL_SOLVER_HPP_
  3. #include SCTL_INCLUDE(math_utils.hpp)
  4. #include SCTL_INCLUDE(matrix.hpp)
  5. #include SCTL_INCLUDE(vector.hpp)
  6. #include SCTL_INCLUDE(comm.hpp)
  7. #include <functional>
  8. namespace SCTL_NAMESPACE {
  9. template <class Real> class ParallelSolver {
  10. public:
  11. using ParallelOp = std::function<void(Vector<Real>*, const Vector<Real>&)>;
  12. ParallelSolver(const Comm& comm = Comm::Self(), bool verbose = true) : comm_(comm), verbose_(verbose) {}
  13. void operator()(Vector<Real>* x, const ParallelOp& A, const Vector<Real>& b, Real tol, Integer max_iter = -1, bool use_abs_tol = false);
  14. private:
  15. Comm comm_;
  16. bool verbose_;
  17. };
  18. } // end namespace
  19. #ifdef SCTL_HAVE_PETSC
  20. #include <petscksp.h>
  21. namespace SCTL_NAMESPACE {
  22. template <class Real> int ParallelSolverMatVec(Mat M_, Vec x_, Vec Mx_) {
  23. PetscErrorCode ierr;
  24. PetscInt N, N_;
  25. VecGetLocalSize(x_, &N);
  26. VecGetLocalSize(Mx_, &N_);
  27. SCTL_ASSERT(N == N_);
  28. void* data = nullptr;
  29. MatShellGetContext(M_, &data);
  30. auto& M = dynamic_cast<const typename ParallelSolver<Real>::ParallelOp&>(*(typename ParallelSolver<Real>::ParallelOp*)data);
  31. const PetscScalar* x_ptr;
  32. ierr = VecGetArrayRead(x_, &x_ptr);
  33. CHKERRQ(ierr);
  34. Vector<Real> x(N);
  35. for (Long i = 0; i < N; i++) x[i] = x_ptr[i];
  36. Vector<Real> Mx(N);
  37. M(&Mx, x);
  38. PetscScalar* Mx_ptr;
  39. ierr = VecGetArray(Mx_, &Mx_ptr);
  40. CHKERRQ(ierr);
  41. for (long i = 0; i < N; i++) Mx_ptr[i] = Mx[i];
  42. ierr = VecRestoreArray(Mx_, &Mx_ptr);
  43. CHKERRQ(ierr);
  44. return 0;
  45. }
  46. template <class Real> inline void ParallelSolver<Real>::operator()(Vector<Real>* x, const ParallelOp& A, const Vector<Real>& b, Real tol, Integer max_iter, bool use_abs_tol) {
  47. PetscInt N = b.Dim();
  48. if (max_iter < 0) max_iter = N;
  49. MPI_Comm comm = comm_.GetMPI_Comm();
  50. PetscErrorCode ierr;
  51. Mat PetscA;
  52. { // Create Matrix. PetscA
  53. MatCreateShell(comm, N, N, PETSC_DETERMINE, PETSC_DETERMINE, (void*)&A, &PetscA);
  54. MatShellSetOperation(PetscA, MATOP_MULT, (void (*)(void))ParallelSolverMatVec<Real>);
  55. }
  56. Vec Petsc_x, Petsc_b;
  57. { // Create vectors
  58. VecCreateMPI(comm, N, PETSC_DETERMINE, &Petsc_b);
  59. VecCreateMPI(comm, N, PETSC_DETERMINE, &Petsc_x);
  60. PetscScalar* b_ptr;
  61. ierr = VecGetArray(Petsc_b, &b_ptr);
  62. CHKERRABORT(comm, ierr);
  63. for (long i = 0; i < N; i++) b_ptr[i] = b[i];
  64. ierr = VecRestoreArray(Petsc_b, &b_ptr);
  65. CHKERRABORT(comm, ierr);
  66. }
  67. // Create linear solver context
  68. KSP ksp;
  69. ierr = KSPCreate(comm, &ksp);
  70. CHKERRABORT(comm, ierr);
  71. // Set operators. Here the matrix that defines the linear system
  72. // also serves as the preconditioning matrix.
  73. ierr = KSPSetOperators(ksp, PetscA, PetscA);
  74. CHKERRABORT(comm, ierr);
  75. // Set runtime options
  76. KSPSetType(ksp, KSPGMRES);
  77. KSPSetNormType(ksp, KSP_NORM_UNPRECONDITIONED);
  78. if (use_abs_tol) KSPSetTolerances(ksp, PETSC_DEFAULT, tol, PETSC_DEFAULT, max_iter);
  79. else KSPSetTolerances(ksp, tol, PETSC_DEFAULT, PETSC_DEFAULT, max_iter);
  80. KSPGMRESSetOrthogonalization(ksp, KSPGMRESModifiedGramSchmidtOrthogonalization);
  81. //if (verbose_) KSPMonitorSet(ksp, KSPMonitorDefault, nullptr, nullptr); // Doesn't work for some versions of PETSc!! WTH!!
  82. KSPGMRESSetRestart(ksp, max_iter);
  83. ierr = KSPSetFromOptions(ksp);
  84. CHKERRABORT(comm, ierr);
  85. // -------------------------------------------------------------------
  86. // Solve the linear system: Ax=b
  87. // -------------------------------------------------------------------
  88. ierr = KSPSolve(ksp, Petsc_b, Petsc_x);
  89. CHKERRABORT(comm, ierr);
  90. // View info about the solver
  91. // KSPView(ksp,PETSC_VIEWER_STDOUT_WORLD); CHKERRABORT(comm, ierr);
  92. // Iterations
  93. // PetscInt its;
  94. // ierr = KSPGetIterationNumber(ksp,&its); CHKERRABORT(comm, ierr);
  95. // ierr = PetscPrintf(PETSC_COMM_WORLD,"Iterations %D\n",its); CHKERRABORT(comm, ierr);
  96. { // Set x
  97. const PetscScalar* x_ptr;
  98. ierr = VecGetArrayRead(Petsc_x, &x_ptr);
  99. CHKERRABORT(comm, ierr);
  100. if (x->Dim() != N) x->ReInit(N);
  101. for (long i = 0; i < N; i++) (*x)[i] = x_ptr[i];
  102. }
  103. ierr = KSPDestroy(&ksp);
  104. CHKERRABORT(comm, ierr);
  105. ierr = MatDestroy(&PetscA);
  106. CHKERRABORT(comm, ierr);
  107. ierr = VecDestroy(&Petsc_x);
  108. CHKERRABORT(comm, ierr);
  109. ierr = VecDestroy(&Petsc_b);
  110. CHKERRABORT(comm, ierr);
  111. }
  112. } // end namespace
  113. #else
  114. namespace SCTL_NAMESPACE {
  115. template <class Real> static Real inner_prod(const Vector<Real>& x, const Vector<Real>& y, const Comm& comm) {
  116. Real x_dot_y = 0;
  117. Long N = x.Dim();
  118. SCTL_ASSERT(y.Dim() == N);
  119. for (Long i = 0; i < N; i++) x_dot_y += x[i] * y[i];
  120. Real x_dot_y_glb = 0;
  121. comm.Allreduce(Ptr2ConstItr<Real>(&x_dot_y, 1), Ptr2Itr<Real>(&x_dot_y_glb, 1), 1, Comm::CommOp::SUM);
  122. return x_dot_y_glb;
  123. }
  124. template <class Real> inline void ParallelSolver<Real>::operator()(Vector<Real>* x, const ParallelOp& A, const Vector<Real>& b, Real tol, Integer max_iter, bool use_abs_tol) {
  125. Long N = b.Dim();
  126. if (max_iter < 0) max_iter = N;
  127. { // Initialize x
  128. if (x->Dim() != N) x->ReInit(N);
  129. x->SetZero();
  130. }
  131. Real b_norm = sqrt(inner_prod(b, b, comm_));
  132. if (b_norm == 0) return;
  133. Vector<Real> q(N);
  134. Vector<Vector<Real>> Q;
  135. Vector<Vector<Real>> H;
  136. { // Initialize q, Q
  137. q = b;
  138. Real one_over_q_norm = 1.0 / sqrt<Real>(inner_prod(q, q, comm_));
  139. for (Long j = 0; j < N; j++) q[j] *= one_over_q_norm;
  140. Q.PushBack(q);
  141. }
  142. Matrix<Real> H_;
  143. Vector<Real> Aq(N), y, h, r = b;
  144. Real abs_tol = tol * (use_abs_tol ? 1 : b_norm);
  145. while (1) {
  146. Real r_norm = sqrt(inner_prod(r, r, comm_));
  147. if (verbose_ && !comm_.Rank()) printf("%3lld KSP Residual norm %.12e\n", (long long)H.Dim(), r_norm);
  148. if (r_norm < abs_tol || H.Dim() == max_iter) break;
  149. A(&Aq, q);
  150. q = Aq;
  151. h.ReInit(Q.Dim() + 1);
  152. for (Integer i = 0; i < Q.Dim(); i++) { // Orthogonalized q
  153. h[i] = inner_prod(q, Q[i], comm_);
  154. for (Long j = 0; j < N; j++) q[j] -= h[i] * Q[i][j];
  155. }
  156. { // Normalize q
  157. h[Q.Dim()] = sqrt<Real>(inner_prod(q, q, comm_));
  158. if (h[Q.Dim()] == 0) break;
  159. Real one_over_q_norm = 1.0 / h[Q.Dim()];
  160. for (Long j = 0; j < N; j++) q[j] *= one_over_q_norm;
  161. }
  162. Q.PushBack(q);
  163. H.PushBack(h);
  164. { // Set y
  165. H_.ReInit(H.Dim(), Q.Dim());
  166. H_.SetZero();
  167. for (Integer i = 0; i < H.Dim(); i++) {
  168. for (Integer j = 0; j < H[i].Dim(); j++) {
  169. H_[i][j] = H[i][j];
  170. }
  171. }
  172. H_ = H_.pinv(0);
  173. y.ReInit(H_.Dim(1));
  174. for (Integer i = 0; i < y.Dim(); i++) {
  175. y[i] = H_[0][i] * b_norm;
  176. }
  177. }
  178. { // Compute residual
  179. Vector<Real> Hy(Q.Dim());
  180. Hy.SetZero();
  181. for (Integer i = 0; i < H.Dim(); i++) {
  182. for (Integer j = 0; j < H[i].Dim(); j++) {
  183. Hy[j] += H[i][j] * y[i];
  184. }
  185. }
  186. Vector<Real> QHy(N);
  187. QHy.SetZero();
  188. for (Integer i = 0; i < Q.Dim(); i++) {
  189. for (Long j = 0; j < N; j++) {
  190. QHy[j] += Q[i][j] * Hy[i];
  191. }
  192. }
  193. for (Integer j = 0; j < N; j++) { // Set r
  194. r[j] = b[j] - QHy[j];
  195. }
  196. }
  197. }
  198. { // Set x
  199. for (Integer i = 0; i < y.Dim(); i++) {
  200. for (Integer j = 0; j < N; j++) {
  201. (*x)[j] += y[i] * Q[i][j];
  202. }
  203. }
  204. }
  205. }
  206. } // end namespace
  207. #endif
  208. #endif //_SCTL_PARALLEL_SOLVER_HPP_