parallel_solver.hpp 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260
  1. #ifndef _SCTL_PARALLEL_SOLVER_HPP_
  2. #define _SCTL_PARALLEL_SOLVER_HPP_
  3. #include SCTL_INCLUDE(vector.hpp)
  4. #include SCTL_INCLUDE(comm.hpp)
  5. #include <functional>
  6. namespace SCTL_NAMESPACE {
  7. template <class Real> class ParallelSolver {
  8. public:
  9. using ParallelOp = std::function<void(Vector<Real>*, const Vector<Real>&)>;
  10. ParallelSolver(const Comm& comm = Comm::Self(), bool verbose = true) : comm_(comm), verbose_(verbose) {}
  11. void operator()(Vector<Real>* x, const ParallelOp& A, const Vector<Real>& b, Real tol, Integer max_iter = -1, bool use_abs_tol = false);
  12. private:
  13. Comm comm_;
  14. bool verbose_;
  15. };
  16. } // end namespace
  17. #ifdef SCTL_HAVE_PETSC
  18. #include <petscksp.h>
  19. namespace SCTL_NAMESPACE {
  20. template <class Real> int ParallelSolverMatVec(Mat M_, Vec x_, Vec Mx_) {
  21. PetscErrorCode ierr;
  22. PetscInt N, N_;
  23. VecGetLocalSize(x_, &N);
  24. VecGetLocalSize(Mx_, &N_);
  25. SCTL_ASSERT(N == N_);
  26. void* data = nullptr;
  27. MatShellGetContext(M_, &data);
  28. auto& M = dynamic_cast<const typename ParallelSolver<Real>::ParallelOp&>(*(typename ParallelSolver<Real>::ParallelOp*)data);
  29. const PetscScalar* x_ptr;
  30. ierr = VecGetArrayRead(x_, &x_ptr);
  31. CHKERRQ(ierr);
  32. Vector<Real> x(N);
  33. for (Long i = 0; i < N; i++) x[i] = x_ptr[i];
  34. Vector<Real> Mx(N);
  35. M(&Mx, x);
  36. PetscScalar* Mx_ptr;
  37. ierr = VecGetArray(Mx_, &Mx_ptr);
  38. CHKERRQ(ierr);
  39. for (long i = 0; i < N; i++) Mx_ptr[i] = Mx[i];
  40. ierr = VecRestoreArray(Mx_, &Mx_ptr);
  41. CHKERRQ(ierr);
  42. return 0;
  43. }
  44. template <class Real> inline void ParallelSolver<Real>::operator()(Vector<Real>* x, const ParallelOp& A, const Vector<Real>& b, Real tol, Integer max_iter, bool use_abs_tol) {
  45. PetscInt N = b.Dim();
  46. if (max_iter < 0) max_iter = N;
  47. MPI_Comm comm = comm_.GetMPI_Comm();
  48. PetscErrorCode ierr;
  49. Mat PetscA;
  50. { // Create Matrix. PetscA
  51. MatCreateShell(comm, N, N, PETSC_DETERMINE, PETSC_DETERMINE, (void*)&A, &PetscA);
  52. MatShellSetOperation(PetscA, MATOP_MULT, (void (*)(void))ParallelSolverMatVec<Real>);
  53. }
  54. Vec Petsc_x, Petsc_b;
  55. { // Create vectors
  56. VecCreateMPI(comm, N, PETSC_DETERMINE, &Petsc_b);
  57. VecCreateMPI(comm, N, PETSC_DETERMINE, &Petsc_x);
  58. PetscScalar* b_ptr;
  59. ierr = VecGetArray(Petsc_b, &b_ptr);
  60. CHKERRABORT(comm, ierr);
  61. for (long i = 0; i < N; i++) b_ptr[i] = b[i];
  62. ierr = VecRestoreArray(Petsc_b, &b_ptr);
  63. CHKERRABORT(comm, ierr);
  64. }
  65. // Create linear solver context
  66. KSP ksp;
  67. ierr = KSPCreate(comm, &ksp);
  68. CHKERRABORT(comm, ierr);
  69. // Set operators. Here the matrix that defines the linear system
  70. // also serves as the preconditioning matrix.
  71. ierr = KSPSetOperators(ksp, PetscA, PetscA);
  72. CHKERRABORT(comm, ierr);
  73. // Set runtime options
  74. KSPSetType(ksp, KSPGMRES);
  75. KSPSetNormType(ksp, KSP_NORM_UNPRECONDITIONED);
  76. if (use_abs_tol) KSPSetTolerances(ksp, PETSC_DEFAULT, tol, PETSC_DEFAULT, max_iter);
  77. else KSPSetTolerances(ksp, tol, PETSC_DEFAULT, PETSC_DEFAULT, max_iter);
  78. KSPGMRESSetOrthogonalization(ksp, KSPGMRESModifiedGramSchmidtOrthogonalization);
  79. //if (verbose_) KSPMonitorSet(ksp, KSPMonitorDefault, nullptr, nullptr); // Doesn't work for some versions of PETSc!! WTH!!
  80. KSPGMRESSetRestart(ksp, max_iter);
  81. ierr = KSPSetFromOptions(ksp);
  82. CHKERRABORT(comm, ierr);
  83. // -------------------------------------------------------------------
  84. // Solve the linear system: Ax=b
  85. // -------------------------------------------------------------------
  86. ierr = KSPSolve(ksp, Petsc_b, Petsc_x);
  87. CHKERRABORT(comm, ierr);
  88. // View info about the solver
  89. // KSPView(ksp,PETSC_VIEWER_STDOUT_WORLD); CHKERRABORT(comm, ierr);
  90. // Iterations
  91. // PetscInt its;
  92. // ierr = KSPGetIterationNumber(ksp,&its); CHKERRABORT(comm, ierr);
  93. // ierr = PetscPrintf(PETSC_COMM_WORLD,"Iterations %D\n",its); CHKERRABORT(comm, ierr);
  94. { // Set x
  95. const PetscScalar* x_ptr;
  96. ierr = VecGetArrayRead(Petsc_x, &x_ptr);
  97. CHKERRABORT(comm, ierr);
  98. if (x->Dim() != N) x->ReInit(N);
  99. for (long i = 0; i < N; i++) (*x)[i] = x_ptr[i];
  100. }
  101. ierr = KSPDestroy(&ksp);
  102. CHKERRABORT(comm, ierr);
  103. ierr = MatDestroy(&PetscA);
  104. CHKERRABORT(comm, ierr);
  105. ierr = VecDestroy(&Petsc_x);
  106. CHKERRABORT(comm, ierr);
  107. ierr = VecDestroy(&Petsc_b);
  108. CHKERRABORT(comm, ierr);
  109. }
  110. } // end namespace
  111. #else
  112. namespace SCTL_NAMESPACE {
  113. template <class Real> static Real inner_prod(const Vector<Real>& x, const Vector<Real>& y, const Comm& comm) {
  114. Real x_dot_y = 0;
  115. Long N = x.Dim();
  116. SCTL_ASSERT(y.Dim() == N);
  117. for (Long i = 0; i < N; i++) x_dot_y += x[i] * y[i];
  118. Real x_dot_y_glb = 0;
  119. comm.Allreduce(Ptr2ConstItr<Real>(&x_dot_y, 1), Ptr2Itr<Real>(&x_dot_y_glb, 1), 1, Comm::CommOp::SUM);
  120. return x_dot_y_glb;
  121. }
  122. template <class Real> inline void ParallelSolver<Real>::operator()(Vector<Real>* x, const ParallelOp& A, const Vector<Real>& b, Real tol, Integer max_iter, bool use_abs_tol) {
  123. Long N = b.Dim();
  124. if (max_iter < 0) max_iter = N;
  125. { // Initialize x
  126. if (x->Dim() != N) x->ReInit(N);
  127. x->SetZero();
  128. }
  129. Real b_norm = sqrt(inner_prod(b, b, comm_));
  130. if (b_norm == 0) return;
  131. Vector<Real> q(N);
  132. Vector<Vector<Real>> Q;
  133. Vector<Vector<Real>> H;
  134. { // Initialize q, Q
  135. q = b;
  136. Real one_over_q_norm = 1.0 / sqrt<Real>(inner_prod(q, q, comm_));
  137. for (Long j = 0; j < N; j++) q[j] *= one_over_q_norm;
  138. Q.PushBack(q);
  139. }
  140. Matrix<Real> H_;
  141. Vector<Real> Aq(N), y, h, r = b;
  142. Real abs_tol = tol * (use_abs_tol ? 1 : b_norm);
  143. while (1) {
  144. Real r_norm = sqrt(inner_prod(r, r, comm_));
  145. if (verbose_ && !comm_.Rank()) printf("%3lld KSP Residual norm %.12e\n", (long long)H.Dim(), r_norm);
  146. if (r_norm < abs_tol || H.Dim() == max_iter) break;
  147. A(&Aq, q);
  148. q = Aq;
  149. h.ReInit(Q.Dim() + 1);
  150. for (Integer i = 0; i < Q.Dim(); i++) { // Orthogonalized q
  151. h[i] = inner_prod(q, Q[i], comm_);
  152. for (Long j = 0; j < N; j++) q[j] -= h[i] * Q[i][j];
  153. }
  154. { // Normalize q
  155. h[Q.Dim()] = sqrt<Real>(inner_prod(q, q, comm_));
  156. if (h[Q.Dim()] == 0) break;
  157. Real one_over_q_norm = 1.0 / h[Q.Dim()];
  158. for (Long j = 0; j < N; j++) q[j] *= one_over_q_norm;
  159. }
  160. Q.PushBack(q);
  161. H.PushBack(h);
  162. { // Set y
  163. H_.ReInit(H.Dim(), Q.Dim());
  164. H_.SetZero();
  165. for (Integer i = 0; i < H.Dim(); i++) {
  166. for (Integer j = 0; j < H[i].Dim(); j++) {
  167. H_[i][j] = H[i][j];
  168. }
  169. }
  170. H_ = H_.pinv();
  171. y.ReInit(H_.Dim(1));
  172. for (Integer i = 0; i < y.Dim(); i++) {
  173. y[i] = H_[0][i] * b_norm;
  174. }
  175. }
  176. { // Compute residual
  177. Vector<Real> Hy(Q.Dim());
  178. Hy.SetZero();
  179. for (Integer i = 0; i < H.Dim(); i++) {
  180. for (Integer j = 0; j < H[i].Dim(); j++) {
  181. Hy[j] += H[i][j] * y[i];
  182. }
  183. }
  184. Vector<Real> QHy(N);
  185. QHy.SetZero();
  186. for (Integer i = 0; i < Q.Dim(); i++) {
  187. for (Long j = 0; j < N; j++) {
  188. QHy[j] += Q[i][j] * Hy[i];
  189. }
  190. }
  191. for (Integer j = 0; j < N; j++) { // Set r
  192. r[j] = b[j] - QHy[j];
  193. }
  194. }
  195. }
  196. { // Set x
  197. for (Integer i = 0; i < y.Dim(); i++) {
  198. for (Integer j = 0; j < N; j++) {
  199. (*x)[j] += y[i] * Q[i][j];
  200. }
  201. }
  202. }
  203. }
  204. } // end namespace
  205. #endif
  206. #endif //_SCTL_PARALLEL_SOLVER_HPP_