Browse Source

Fix bug in SVD.

Dhairya Malhotra 8 years ago
parent
commit
0a08947478
1 changed files with 8 additions and 1 deletions
  1. 8 1
      include/pvfmm/mat_utils.txx

+ 8 - 1
include/pvfmm/mat_utils.txx

@@ -111,6 +111,7 @@ template <> inline void cublasgemm<double>(char TransA, char TransB, int M, int
 
 template <class ValueType> static inline void GivensL(Iterator<ValueType> S_, StaticArray<Long, 2> &dim, Long m, ValueType a, ValueType b) {
   ValueType r = pvfmm::sqrt<ValueType>(a * a + b * b);
+  if (r == 0) return;
   ValueType c = a / r;
   ValueType s = -b / r;
 
@@ -128,6 +129,7 @@ template <class ValueType> static inline void GivensL(Iterator<ValueType> S_, St
 
 template <class ValueType> static inline void GivensR(Iterator<ValueType> S_, StaticArray<Long, 2> &dim, Long m, ValueType a, ValueType b) {
   ValueType r = pvfmm::sqrt<ValueType>(a * a + b * b);
+  if (r == 0) return;
   ValueType c = a / r;
   ValueType s = -b / r;
 
@@ -166,6 +168,7 @@ template <class ValueType> static inline void SVD(StaticArray<Long, 2> &dim, Ite
 
         ValueType alpha = pvfmm::sqrt<ValueType>(1 + x1 * x_inv_norm);
         ValueType beta = x_inv_norm / alpha;
+        if (x_inv_norm == 0) alpha = 0; // nothing to do
 
         house_vec[i] = -alpha;
         for (Long j = i + 1; j < dim[0]; j++) {
@@ -211,6 +214,7 @@ template <class ValueType> static inline void SVD(StaticArray<Long, 2> &dim, Ite
 
         ValueType alpha = pvfmm::sqrt<ValueType>(1 + x1 * x_inv_norm);
         ValueType beta = x_inv_norm / alpha;
+        if (x_inv_norm == 0) alpha = 0; // nothing to do
 
         house_vec[i + 1] = -alpha;
         for (Long j = i + 2; j < dim[1]; j++) {
@@ -267,7 +271,10 @@ template <class ValueType> static inline void SVD(StaticArray<Long, 2> &dim, Ite
 
     ValueType alpha = 0;
     ValueType beta = 0;
-    {  // Compute mu
+    if (n - k0 == 2 && S(k0, k0) == 0 && S(k0 + 1, k0 + 1) == 0) { // Compute mu
+      alpha=0;
+      beta=1;
+    } else {
       StaticArray<ValueType, 2 * 2> C;
       C[0 * 2 + 0] = S(n - 2, n - 2) * S(n - 2, n - 2);
       if (n - k0 > 2) C[0 * 2 + 0] += S(n - 3, n - 2) * S(n - 3, n - 2);