Dhairya Malhotra 7 years ago
parent
commit
ec2ea53d18
1 changed files with 9 additions and 4 deletions
  1. 9 4
      include/sctl/parallel_solver.hpp

+ 9 - 4
include/sctl/parallel_solver.hpp

@@ -18,8 +18,8 @@ template <class Real> class ParallelSolver {
   void operator()(Vector<Real>* x, const ParallelOp& A, const Vector<Real>& b, Real tol, Integer max_iter = -1);
 
  private:
-  bool verbose_;
   Comm comm_;
+  bool verbose_;
 };
 
 }  // end namespace
@@ -160,7 +160,13 @@ template <class Real> static Real inner_prod(const Vector<Real>& x, const Vector
 template <class Real> inline void ParallelSolver<Real>::operator()(Vector<Real>* x, const ParallelOp& A, const Vector<Real>& b, Real tol, Integer max_iter) {
   Long N = b.Dim();
   if (max_iter < 0) max_iter = N;
+
+  { // Initialize x
+    if (x->Dim() != N) x->ReInit(N);
+    x->SetZero();
+  }
   Real b_norm = sqrt(inner_prod(b, b, comm_));
+  if (b_norm == 0) return;
 
   Vector<Real> q(N);
   Vector<Vector<Real>> Q;
@@ -176,7 +182,7 @@ template <class Real> inline void ParallelSolver<Real>::operator()(Vector<Real>*
   Vector<Real> Aq(N), y, h, r = b;
   while (1) {
     Real r_norm = sqrt(inner_prod(r, r, comm_));
-    if (verbose_ && !comm_.Rank()) printf("%3d KSP Residual norm %.12e\n", H.Dim(), r_norm);
+    if (verbose_ && !comm_.Rank()) printf("%3lld KSP Residual norm %.12e\n", (long long)H.Dim(), r_norm);
     if (r_norm < tol * b_norm || H.Dim() == max_iter) break;
 
     A(&Aq, q);
@@ -189,6 +195,7 @@ template <class Real> inline void ParallelSolver<Real>::operator()(Vector<Real>*
     }
     {  // Normalize q
       h[Q.Dim()] = sqrt<Real>(inner_prod(q, q, comm_));
+      if (h[Q.Dim()] == 0) break;
       Real one_over_q_norm = 1.0 / h[Q.Dim()];
       for (Long j = 0; j < N; j++) q[j] *= one_over_q_norm;
     }
@@ -235,8 +242,6 @@ template <class Real> inline void ParallelSolver<Real>::operator()(Vector<Real>*
   }
 
   {  // Set x
-    if (x->Dim() != N) x->ReInit(N);
-    x->SetZero();
     for (Integer i = 0; i < y.Dim(); i++) {
       for (Integer j = 0; j < N; j++) {
         (*x)[j] += y[i] * Q[i][j];