parallel_solver.hpp 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349
  1. #ifndef _SCTL_PARALLEL_SOLVER_HPP_
  2. #define _SCTL_PARALLEL_SOLVER_HPP_
  3. #include <sctl/common.hpp>
  4. #include SCTL_INCLUDE(comm.hpp)
  5. #include SCTL_INCLUDE(mem_mgr.hpp)
  6. #include SCTL_INCLUDE(math_utils.hpp)
  7. #include <functional>
  8. namespace SCTL_NAMESPACE {
  9. template <class ValueType> class Vector;
  10. template <class ValueType> class Matrix;
  11. template <class Real> class ParallelSolver {
  12. public:
  13. using ParallelOp = std::function<void(Vector<Real>*, const Vector<Real>&)>;
  14. ParallelSolver(const Comm& comm = Comm::Self(), bool verbose = true) : comm_(comm), verbose_(verbose) {}
  15. void operator()(Vector<Real>* x, const ParallelOp& A, const Vector<Real>& b, const Real tol, const Integer max_iter = -1, const bool use_abs_tol = false);
  16. static void test(Long N = 15) {
  17. srand48(0);
  18. Matrix<Real> A(N, N);
  19. Vector<Real> b(N), x;
  20. for (Long i = 0; i < N; i++) {
  21. b[i] = drand48();
  22. for (Long j = 0; j < N; j++) {
  23. A[i][j] = drand48();
  24. }
  25. }
  26. auto LinOp = [&A](Vector<Real>* Ax, const Vector<Real>& x) {
  27. const Long N = x.Dim();
  28. Ax->ReInit(N);
  29. Matrix<Real> Ax_(N, 1, Ax->begin(), false);
  30. Ax_ = A * Matrix<Real>(N, 1, (Iterator<Real>)x.begin(), false);
  31. };
  32. ParallelSolver<Real> solver;
  33. solver(&x, LinOp, b, 1e-10, -1, false);
  34. auto print_error = [N,&A,&b](const Vector<Real>& x) {
  35. Real max_err = 0;
  36. auto Merr = A*Matrix<Real>(N, 1, (Iterator<Real>)x.begin(), false) - Matrix<Real>(N, 1, b.begin(), false);
  37. for (const auto& a : Merr) max_err = std::max(max_err, fabs(a));
  38. std::cout<<"Maximum error = "<<max_err<<'\n';
  39. };
  40. print_error(x);
  41. }
  42. private:
  43. void GenericGMRES(Vector<Real>* x, const ParallelOp& A, const Vector<Real>& b, const Real tol, Integer max_iter, const bool use_abs_tol);
  44. Comm comm_;
  45. bool verbose_;
  46. };
  47. } // end namespace
  48. namespace SCTL_NAMESPACE {
  49. template <class Real> static Real inner_prod(const Vector<Real>& x, const Vector<Real>& y, const Comm& comm) {
  50. Real x_dot_y = 0;
  51. Long N = x.Dim();
  52. SCTL_ASSERT(y.Dim() == N);
  53. for (Long i = 0; i < N; i++) x_dot_y += x[i] * y[i];
  54. Real x_dot_y_glb = 0;
  55. comm.Allreduce(Ptr2ConstItr<Real>(&x_dot_y, 1), Ptr2Itr<Real>(&x_dot_y_glb, 1), 1, Comm::CommOp::SUM);
  56. return x_dot_y_glb;
  57. }
  58. template <class Real> inline void ParallelSolver<Real>::GenericGMRES(Vector<Real>* x, const ParallelOp& A, const Vector<Real>& b, Real tol, Integer max_iter, bool use_abs_tol) {
  59. const Long N = b.Dim();
  60. if (max_iter < 0) { // set max_iter
  61. StaticArray<Long,2> NN{N,0};
  62. comm_.Allreduce(NN+0, NN+1, 1, Comm::CommOp::SUM);
  63. max_iter = NN[1];
  64. }
  65. static constexpr Real ARRAY_RESIZE_FACTOR = 1.618;
  66. Vector<Real> Q_mat, H_mat;
  67. auto ResizeVector = [](Vector<Real>& v, const Long N0) {
  68. if (v.Dim() < N0) {
  69. Vector<Real> v_(N0);
  70. for (Long i = 0; i < v.Dim(); i++) v_[i] = v[i];
  71. for (Long i = v.Dim(); i < N0; i++) v_[i] = 0;
  72. v.Swap(v_);
  73. }
  74. };
  75. auto Q_row = [N,&Q_mat,&ResizeVector](Long i) -> Iterator<Real> {
  76. const Long idx = i*N;
  77. if (Q_mat.Dim() <= idx+N) {
  78. ResizeVector(Q_mat, (Long)((idx+N)*ARRAY_RESIZE_FACTOR));
  79. }
  80. return Q_mat.begin() + idx;
  81. };
  82. auto Q = [&Q_row](Long i, Long j) -> Real& {
  83. return Q_row(i)[j];
  84. };
  85. auto H_row = [&H_mat,&ResizeVector](Long i) -> Iterator<Real> {
  86. const Long idx = i*(i+1)/2;
  87. if (H_mat.Dim() <= idx+i+1) ResizeVector(H_mat, (Long)((idx+i+1)*ARRAY_RESIZE_FACTOR));
  88. return H_mat.begin() + idx;
  89. };
  90. auto H = [&H_row](Long i, Long j) -> Real& {
  91. return H_row(i)[j];
  92. };
  93. auto apply_givens_rotation = [](Vector<Real>& h, Real& cs_k, Real& sn_k, const Vector<Real>& cs, const Vector<Real>& sn, const Long k) {
  94. // apply for ith row
  95. for (Long i = 0; i < k; i++) {
  96. Real temp = cs[i] * h[i] + sn[i] * h[i+1];
  97. h[i+1] = -sn[i] * h[i] + cs[i] * h[i+1];
  98. h[i] = temp;
  99. }
  100. // update the next sin cos values for rotation
  101. const Real t = sqrt<Real>(h[k]*h[k] + h[k+1]*h[k+1]);
  102. cs_k = h[k] / t;
  103. sn_k = h[k+1] / t;
  104. // eliminate H(i + 1, i)
  105. h[k] = cs_k * h[k] + sn_k * h[k+1];
  106. h[k+1] = 0.0;
  107. };
  108. auto arnoldi = [this,N,&Q_row,&Q](Vector<Real>& h, Vector<Real>& q, const ParallelOp& A, const Long k) {
  109. q.ReInit(N); // Krylov Vector
  110. A(&q, Vector<Real>(N, Q_row(k), false));
  111. for (Long i = 0; i < k+1; i++) { // Modified Gram-Schmidt, keeping the Hessenberg matrix
  112. h[i] = inner_prod(q, Vector<Real>(N, Q_row(i), false), comm_);
  113. for (Long j = 0; j < N; j++) {
  114. q[j] -= h[i] * Q(i,j);
  115. }
  116. }
  117. h[k+1] = sqrt<Real>(inner_prod(q, q, comm_));
  118. q *= 1/h[k+1];
  119. };
  120. Vector<Real> r;
  121. if (x->Dim()) { // r = b - A * x;
  122. Vector<Real> Ax;
  123. A(&Ax, *x);
  124. r = b - Ax;
  125. } else {
  126. r = b;
  127. x->ReInit(N);
  128. x->SetZero();
  129. }
  130. const Real b_norm = sqrt<Real>(inner_prod(b, b, comm_));
  131. const Real abs_tol = tol * (use_abs_tol ? 1 : b_norm);
  132. const Real r_norm = sqrt<Real>(inner_prod(r, r, comm_));
  133. for (Long i = 0; i < N; i++) Q(0,i) = r[i] / r_norm;
  134. Vector<Real> beta(1); beta = r_norm;
  135. Vector<Real> sn, cs, h_k, q_k(N);
  136. Long k = 0;
  137. Real error = r_norm;
  138. for (; k < max_iter && error > abs_tol; k++) {
  139. if (verbose_ && !comm_.Rank()) printf("%3lld KSP Residual norm %.12e\n", (long long)k, (double)error);
  140. if (sn.Dim() <= k) ResizeVector(sn, (Long)((k+1)*ARRAY_RESIZE_FACTOR));
  141. if (cs.Dim() <= k) ResizeVector(cs, (Long)((k+1)*ARRAY_RESIZE_FACTOR));
  142. if (beta.Dim() <= k+1) ResizeVector(beta, (Long)((k+2)*ARRAY_RESIZE_FACTOR));
  143. if ( h_k.Dim() <= k+1) ResizeVector( h_k, (Long)((k+2)*ARRAY_RESIZE_FACTOR));
  144. arnoldi(h_k, q_k, A, k);
  145. apply_givens_rotation(h_k, cs[k], sn[k], cs, sn, k); // eliminate the last element in H ith row and update the rotation matrix
  146. for (Long i = 0; i < k+1; i++) H(k,i) = h_k[i];
  147. for (Long i = 0; i < N; i++) Q(k+1,i) = q_k[i];
  148. // update the residual vector
  149. beta[k+1] = -sn[k] * beta[k];
  150. beta[k] = cs[k] * beta[k];
  151. error = fabs(beta[k+1]);
  152. }
  153. if (verbose_ && !comm_.Rank()) printf("%3lld KSP Residual norm %.12e\n", (long long)k, (double)error);
  154. for (Long i = k-1; i >= 0; i--) { // beta <-- beta * inv(H); (through back substitution)
  155. beta[i] /= H(i,i);
  156. for (Long j = 0; j < i; j++) {
  157. beta[j] -= beta[i] * H(i,j);
  158. }
  159. }
  160. for (Long i = 0; i < N; i++) { // x <-- beta * Q
  161. for (Long j = 0; j < k; j++) {
  162. (*x)[i] += beta[j] * Q(j,i);
  163. }
  164. }
  165. }
  166. template <class Real> inline void ParallelSolver<Real>::operator()(Vector<Real>* x, const ParallelOp& A, const Vector<Real>& b, const Real tol, const Integer max_iter, const bool use_abs_tol) {
  167. GenericGMRES(x, A, b, tol, max_iter, use_abs_tol);
  168. }
  169. } // end namespace
  170. #ifdef SCTL_HAVE_PETSC
  171. #include <petscksp.h>
  172. namespace SCTL_NAMESPACE {
  173. template <class Real> int ParallelSolverMatVec(Mat M_, ::Vec x_, ::Vec Mx_) {
  174. PetscErrorCode ierr;
  175. PetscInt N, N_;
  176. VecGetLocalSize(x_, &N);
  177. VecGetLocalSize(Mx_, &N_);
  178. SCTL_ASSERT(N == N_);
  179. void* data = nullptr;
  180. MatShellGetContext(M_, &data);
  181. auto& M = dynamic_cast<const typename ParallelSolver<Real>::ParallelOp&>(*(typename ParallelSolver<Real>::ParallelOp*)data);
  182. const PetscScalar* x_ptr;
  183. ierr = VecGetArrayRead(x_, &x_ptr);
  184. CHKERRQ(ierr);
  185. Vector<Real> x(N);
  186. for (Long i = 0; i < N; i++) x[i] = (Real)x_ptr[i];
  187. Vector<Real> Mx(N);
  188. M(&Mx, x);
  189. PetscScalar* Mx_ptr;
  190. ierr = VecGetArray(Mx_, &Mx_ptr);
  191. CHKERRQ(ierr);
  192. for (long i = 0; i < N; i++) Mx_ptr[i] = Mx[i];
  193. ierr = VecRestoreArray(Mx_, &Mx_ptr);
  194. CHKERRQ(ierr);
  195. return 0;
  196. }
  197. PetscErrorCode MyKSPMonitor(KSP ksp, PetscInt n, PetscReal rnorm, void *dummy) {
  198. Comm* comm = (Comm*)dummy;
  199. if (!comm->Rank()) printf("%3lld KSP Residual norm %.12e\n", (long long)n, (double)rnorm);
  200. //PetscPrintf(PETSC_COMM_WORLD,"iteration %D KSP Residual norm %14.12e \n",n,rnorm);
  201. //PetscViewerAndFormat *vf;
  202. //PetscViewerAndFormatCreate(PETSC_VIEWER_STDOUT_WORLD, PETSC_VIEWER_DEFAULT, &vf);
  203. //KSPMonitorResidual(ksp, n, rnorm, vf);
  204. //PetscViewerAndFormatDestroy(&vf);
  205. return 0;
  206. }
  207. template <class Real> inline void PETScGMRES(Vector<Real>* x, const typename ParallelSolver<Real>::ParallelOp& A, const Vector<Real>& b, const Real tol, Integer max_iter, const bool use_abs_tol, const bool verbose_, const Comm& comm_) {
  208. PetscInt N = b.Dim();
  209. if (max_iter < 0) { // set max_iter
  210. StaticArray<Long,2> NN{N,0};
  211. comm_.Allreduce(NN+0, NN+1, 1, Comm::CommOp::SUM);
  212. max_iter = NN[1];
  213. }
  214. const MPI_Comm comm = comm_.GetMPI_Comm();
  215. PetscErrorCode ierr;
  216. Mat PetscA;
  217. { // Create Matrix. PetscA
  218. MatCreateShell(comm, N, N, PETSC_DETERMINE, PETSC_DETERMINE, (void*)&A, &PetscA);
  219. MatShellSetOperation(PetscA, MATOP_MULT, (void (*)(void))ParallelSolverMatVec<Real>);
  220. }
  221. ::Vec Petsc_x, Petsc_b;
  222. { // Create vectors
  223. VecCreateMPI(comm, N, PETSC_DETERMINE, &Petsc_b);
  224. VecCreateMPI(comm, N, PETSC_DETERMINE, &Petsc_x);
  225. PetscScalar* b_ptr;
  226. ierr = VecGetArray(Petsc_b, &b_ptr);
  227. CHKERRABORT(comm, ierr);
  228. for (long i = 0; i < N; i++) b_ptr[i] = b[i];
  229. ierr = VecRestoreArray(Petsc_b, &b_ptr);
  230. CHKERRABORT(comm, ierr);
  231. }
  232. // Create linear solver context
  233. KSP ksp;
  234. ierr = KSPCreate(comm, &ksp);
  235. CHKERRABORT(comm, ierr);
  236. // Set operators. Here the matrix that defines the linear system
  237. // also serves as the preconditioning matrix.
  238. ierr = KSPSetOperators(ksp, PetscA, PetscA);
  239. CHKERRABORT(comm, ierr);
  240. // Set runtime options
  241. KSPSetType(ksp, KSPGMRES);
  242. KSPSetNormType(ksp, KSP_NORM_UNPRECONDITIONED);
  243. if (use_abs_tol) KSPSetTolerances(ksp, PETSC_DEFAULT, tol, PETSC_DEFAULT, max_iter);
  244. else KSPSetTolerances(ksp, tol, PETSC_DEFAULT, PETSC_DEFAULT, max_iter);
  245. KSPGMRESSetOrthogonalization(ksp, KSPGMRESModifiedGramSchmidtOrthogonalization);
  246. if (verbose_) KSPMonitorSet(ksp, MyKSPMonitor, (MPI_Comm)&comm_, nullptr);
  247. KSPGMRESSetRestart(ksp, max_iter);
  248. ierr = KSPSetFromOptions(ksp);
  249. CHKERRABORT(comm, ierr);
  250. // -------------------------------------------------------------------
  251. // Solve the linear system: Ax=b
  252. // -------------------------------------------------------------------
  253. ierr = KSPSolve(ksp, Petsc_b, Petsc_x);
  254. CHKERRABORT(comm, ierr);
  255. // View info about the solver
  256. // KSPView(ksp,PETSC_VIEWER_STDOUT_WORLD); CHKERRABORT(comm, ierr);
  257. // Iterations
  258. // PetscInt its;
  259. // ierr = KSPGetIterationNumber(ksp,&its); CHKERRABORT(comm, ierr);
  260. // ierr = PetscPrintf(PETSC_COMM_WORLD,"Iterations %D\n",its); CHKERRABORT(comm, ierr);
  261. { // Set x
  262. const PetscScalar* x_ptr;
  263. ierr = VecGetArrayRead(Petsc_x, &x_ptr);
  264. CHKERRABORT(comm, ierr);
  265. if (x->Dim() != N) x->ReInit(N);
  266. for (long i = 0; i < N; i++) (*x)[i] = (Real)x_ptr[i];
  267. }
  268. ierr = KSPDestroy(&ksp);
  269. CHKERRABORT(comm, ierr);
  270. ierr = MatDestroy(&PetscA);
  271. CHKERRABORT(comm, ierr);
  272. ierr = VecDestroy(&Petsc_x);
  273. CHKERRABORT(comm, ierr);
  274. ierr = VecDestroy(&Petsc_b);
  275. CHKERRABORT(comm, ierr);
  276. }
  277. template <> inline void ParallelSolver<double>::operator()(Vector<double>* x, const ParallelOp& A, const Vector<double>& b, const double tol, const Integer max_iter, const bool use_abs_tol) {
  278. PETScGMRES(x, A, b, tol, max_iter, use_abs_tol, verbose_, comm_);
  279. }
  280. template <> inline void ParallelSolver<float>::operator()(Vector<float>* x, const ParallelOp& A, const Vector<float>& b, const float tol, const Integer max_iter, const bool use_abs_tol) {
  281. PETScGMRES(x, A, b, tol, max_iter, use_abs_tol, verbose_, comm_);
  282. }
  283. } // end namespace
  284. #endif
  285. #endif //_SCTL_PARALLEL_SOLVER_HPP_