Dhairya Malhotra 7 gadi atpakaļ
vecāks
revīzija
cb9a61e97d

+ 89 - 28
include/sctl/fft_wrapper.hpp

@@ -22,59 +22,102 @@ template <class ValueType> class Complex {
   public:
     Complex<ValueType>(ValueType r=0, ValueType i=0) : real(r), imag(i) {}
 
-    Complex<ValueType> operator*(const Complex<ValueType>& x) const {
+    Complex<ValueType> operator-() const {
       Complex<ValueType> z;
-      z.real = real * x.real - imag * x.imag;
-      z.imag = imag * x.real + real * x.imag;
+      z.real = -real;
+      z.imag = -imag;
       return z;
     }
 
-    Complex<ValueType> operator*(const ValueType& x) const {
+    Complex<ValueType> conj() const {
       Complex<ValueType> z;
-      z.real = real * x;
-      z.imag = imag * x;
+      z.real = real;
+      z.imag = -imag;
       return z;
     }
 
-    Complex<ValueType> operator+(const Complex<ValueType>& x) const {
+
+    bool operator==(const Complex<ValueType>& x) const {
+      return real == x.real && imag == x.imag;
+    }
+
+    bool operator!=(const Complex<ValueType>& x) const {
+      return !((*this) == x);;
+    }
+
+
+    template <class ScalarType> void operator+=(const Complex<ScalarType>& x) {
+      (*this) = (*this) + x;
+    }
+
+    template <class ScalarType> void operator-=(const Complex<ScalarType>& x) {
+      (*this) = (*this) - x;
+    }
+
+    template <class ScalarType> void operator*=(const Complex<ScalarType>& x) {
+      (*this) = (*this) * x;
+    }
+
+    template <class ScalarType> void operator/=(const Complex<ScalarType>& x) {
+      (*this) = (*this) / x;
+    }
+
+
+    template <class ScalarType> Complex<ValueType> operator+(const ScalarType& x) const {
       Complex<ValueType> z;
-      z.real = real + x.real;
-      z.imag = imag + x.imag;
+      z.real = real + x;
+      z.imag = imag;
       return z;
     }
 
-    Complex<ValueType> operator+(const ValueType& x) const {
+    template <class ScalarType> Complex<ValueType> operator-(const ScalarType& x) const {
       Complex<ValueType> z;
-      z.real = real + x;
+      z.real = real - x;
       z.imag = imag;
       return z;
     }
 
-    Complex<ValueType> operator-(const Complex<ValueType>& x) const {
+    template <class ScalarType> Complex<ValueType> operator*(const ScalarType& x) const {
       Complex<ValueType> z;
-      z.real = real - x.real;
-      z.imag = imag - x.imag;
+      z.real = real * x;
+      z.imag = imag * x;
       return z;
     }
 
-    Complex<ValueType> operator-(const ValueType& x) const {
+    template <class ScalarType> Complex<ValueType> operator/(const ScalarType& y) const {
       Complex<ValueType> z;
-      z.real = real - x;
-      z.imag = imag;
+      z.real = real / y;
+      z.imag = imag / y;
       return z;
     }
 
-    Complex<ValueType> operator-() const {
+
+    Complex<ValueType> operator+(const Complex<ValueType>& x) const {
       Complex<ValueType> z;
-      z.real = -real;
-      z.imag = -imag;
+      z.real = real + x.real;
+      z.imag = imag + x.imag;
       return z;
     }
 
-    Complex<ValueType> conj() const {
+    Complex<ValueType> operator-(const Complex<ValueType>& x) const {
       Complex<ValueType> z;
-      z.real = real;
-      z.imag = -imag;
+      z.real = real - x.real;
+      z.imag = imag - x.imag;
+      return z;
+    }
+
+    Complex<ValueType> operator*(const Complex<ValueType>& x) const {
+      Complex<ValueType> z;
+      z.real = real * x.real - imag * x.imag;
+      z.imag = imag * x.real + real * x.imag;
+      return z;
+    }
+
+    Complex<ValueType> operator/(const Complex<ValueType>& y) const {
+      Complex<ValueType> z;
+      ValueType y_inv = 1 / (y.real * y.real + y.imag * y.imag);
+      z.real = (y.real * real + y.imag * imag) * y_inv;
+      z.imag = (y.real * imag - y.imag * real) * y_inv;
       return z;
     }
 
@@ -82,27 +125,35 @@ template <class ValueType> class Complex {
     ValueType imag;
 };
 
-template <class ValueType> Complex<ValueType> operator*(const ValueType& x, const Complex<ValueType>& y){
+template <class ScalarType, class ValueType> Complex<ValueType> operator*(const ScalarType& x, const Complex<ValueType>& y) {
   Complex<ValueType> z;
   z.real = y.real * x;
   z.imag = y.imag * x;
   return z;
 }
 
-template <class ValueType> Complex<ValueType> operator+(const ValueType& x, const Complex<ValueType>& y){
+template <class ScalarType, class ValueType> Complex<ValueType> operator+(const ScalarType& x, const Complex<ValueType>& y) {
   Complex<ValueType> z;
   z.real = y.real + x;
   z.imag = y.imag;
   return z;
 }
 
-template <class ValueType> Complex<ValueType> operator-(const ValueType& x, const Complex<ValueType>& y){
+template <class ScalarType, class ValueType> Complex<ValueType> operator-(const ScalarType& x, const Complex<ValueType>& y) {
   Complex<ValueType> z;
   z.real = y.real - x;
   z.imag = y.imag;
   return z;
 }
 
+template <class ScalarType, class ValueType> Complex<ValueType> operator/(const ScalarType& x, const Complex<ValueType>& y) {
+  Complex<ValueType> z;
+  ValueType y_inv = 1 / (y.real * y.real + y.imag * y.imag);
+  z.real =  (y.real * x) * y_inv;
+  z.imag = -(y.imag * x) * y_inv;
+  return z;
+}
+
 enum class FFT_Type {R2C, C2C, C2C_INV, C2R};
 
 template <class ValueType, class FFT_Derived> class FFT_Generic {
@@ -307,8 +358,9 @@ template <class ValueType, class FFT_Derived> class FFT_Generic {
   }
 
   static void check_align(const Vector<ValueType>& in, const Vector<ValueType>& out) {
-    SCTL_ASSERT_MSG((((uintptr_t)& in[0]) & ((uintptr_t)(SCTL_MEM_ALIGN - 1))) == 0, "sctl::FFT: Input vector not aligned to " <<SCTL_MEM_ALIGN<<" bits!");
-    SCTL_ASSERT_MSG((((uintptr_t)&out[0]) & ((uintptr_t)(SCTL_MEM_ALIGN - 1))) == 0, "sctl::FFT: Output vector not aligned to "<<SCTL_MEM_ALIGN<<" bits!");
+    //SCTL_ASSERT_MSG((((uintptr_t)& in[0]) & ((uintptr_t)(SCTL_MEM_ALIGN - 1))) == 0, "sctl::FFT: Input vector not aligned to " <<SCTL_MEM_ALIGN<<" bits!");
+    //SCTL_ASSERT_MSG((((uintptr_t)&out[0]) & ((uintptr_t)(SCTL_MEM_ALIGN - 1))) == 0, "sctl::FFT: Output vector not aligned to "<<SCTL_MEM_ALIGN<<" bits!");
+    // TODO: copy to auxiliary array if unaligned
   }
 
   StaticArray<Long,2> dim;
@@ -356,6 +408,9 @@ template <> class FFT<double> : public FFT_Generic<double, FFT<double>> {
       } else if (fft_type == FFT_Type::C2R) {
         N0 = (N / dim_vec[rank - 1]) * (dim_vec[rank - 1] / 2 + 1) * 2;
         N1 = N;
+      } else {
+        N0 = 0;
+        N1 = 0;
       }
       this->dim[0] = N0;
       this->dim[1] = N1;
@@ -464,6 +519,9 @@ template <> class FFT<float> : public FFT_Generic<float, FFT<float>> {
       } else if (fft_type == FFT_Type::C2R) {
         N0 = (N / dim_vec[rank - 1]) * (dim_vec[rank - 1] / 2 + 1) * 2;
         N1 = N;
+      } else {
+        N0 = 0;
+        N1 = 0;
       }
       this->dim[0] = N0;
       this->dim[1] = N1;
@@ -570,6 +628,9 @@ template <> class FFT<long double> : public FFT_Generic<long double, FFT<long do
       } else if (fft_type == FFT_Type::C2R) {
         N0 = (N / dim_vec[rank - 1]) * (dim_vec[rank - 1] / 2 + 1) * 2;
         N1 = N;
+      } else {
+        N0 = 0;
+        N1 = 0;
       }
       this->dim[0] = N0;
       this->dim[1] = N1;

+ 5 - 2
include/sctl/intrin_wrapper.hpp

@@ -21,6 +21,9 @@
 #include <immintrin.h>
 #endif
 
+// TODO: Check alignment which SCTL_MEMDEBUG is defined
+// TODO: Replace pointers with iterators
+
 namespace SCTL_NAMESPACE {
 
 template <class T> inline T zero_intrin() { return (T)0; }
@@ -43,7 +46,7 @@ template <class T> inline T cmplt_intrin(const T& a, const T& b) {
   T r = 0;
   uint8_t* r_ = reinterpret_cast<uint8_t*>(&r);
   if (a < b)
-    for (int i = 0; i < sizeof(T); i++) r_[i] = ~(uint8_t)0;
+    for (int i = 0; i < (int)sizeof(T); i++) r_[i] = ~(uint8_t)0;
   return r;
 }
 
@@ -52,7 +55,7 @@ template <class T> inline T and_intrin(const T& a, const T& b) {
   const uint8_t* a_ = reinterpret_cast<const uint8_t*>(&a);
   const uint8_t* b_ = reinterpret_cast<const uint8_t*>(&b);
   uint8_t* r_ = reinterpret_cast<uint8_t*>(&r);
-  for (int i = 0; i < sizeof(T); i++) r_[i] = a_[i] & b_[i];
+  for (int i = 0; i < (int)sizeof(T); i++) r_[i] = a_[i] & b_[i];
   return r;
 }
 

+ 90 - 17
include/sctl/sph_harm.hpp

@@ -126,6 +126,8 @@ template <class Real> class SphericalHarmonics{
      */
     static void StokesEvalDL(const Vector<Real>& S, SHCArrange arrange, Long p, const Vector<Real>& coord, bool interior, Vector<Real>& U);
 
+    static void StokesEvalKL(const Vector<Real>& S, SHCArrange arrange, Long p, const Vector<Real>& coord, const Vector<Real>& norm, bool interior, Vector<Real>& U);
+
 
     static void test_stokes() {
       int p = 6;
@@ -144,6 +146,7 @@ template <class Real> class SphericalHarmonics{
 
       Vector<Real> Fcoeff(dof*(p+1)*(p+2));
       for (Long i=0;i<Fcoeff.Dim();i++) Fcoeff[i]=i+1;
+      Fcoeff = 0; Fcoeff[2] = 1;
       print_coeff(Fcoeff);
 
       Vector<Real> Fgrid;
@@ -152,9 +155,9 @@ template <class Real> class SphericalHarmonics{
 
       const Vector<Real> CosTheta = LegendreNodes(Nt-1);
       const Vector<Real> LWeights = LegendreWeights(Nt-1);
-      auto stokes_evalSL = [&](const Vector<Real>& trg, Vector<Real>& Df) {
-        Df.ReInit(3);
-        Df=0;
+      auto stokes_evalSL = [&](const Vector<Real>& trg, Vector<Real>& Sf) {
+        Sf.ReInit(3);
+        Sf=0;
         Real s = 1/(8*const_pi<Real>());
         for (Long i=0;i<Nt;i++) {
           Real cos_theta = CosTheta[i];
@@ -183,15 +186,15 @@ template <class Real> class SphericalHarmonics{
 
             Real rdotf = dr[0]*f[0]+dr[1]*f[1]+dr[2]*f[2];
 
-            Df[0] += s*(f[0]*oor1 + dr[0]*rdotf*oor3) * qw;
-            Df[1] += s*(f[1]*oor1 + dr[1]*rdotf*oor3) * qw;
-            Df[2] += s*(f[2]*oor1 + dr[2]*rdotf*oor3) * qw;
+            Sf[0] += s*(f[0]*oor1 + dr[0]*rdotf*oor3) * qw;
+            Sf[1] += s*(f[1]*oor1 + dr[1]*rdotf*oor3) * qw;
+            Sf[2] += s*(f[2]*oor1 + dr[2]*rdotf*oor3) * qw;
           }
         }
       };
-      auto stokes_evalDL = [&](const Vector<Real>& trg, Vector<Real>& Df) {
-        Df.ReInit(3);
-        Df=0;
+      auto stokes_evalDL = [&](const Vector<Real>& trg, Vector<Real>& Sf) {
+        Sf.ReInit(3);
+        Sf=0;
         Real s = 6/(8*const_pi<Real>());
         for (Long i=0;i<Nt;i++) {
           Real cos_theta = CosTheta[i];
@@ -224,9 +227,69 @@ template <class Real> class SphericalHarmonics{
             Real rdotn = dr[0]*n[0]+dr[1]*n[1]+dr[2]*n[2];
             Real rdotf = dr[0]*f[0]+dr[1]*f[1]+dr[2]*f[2];
 
-            Df[0] += -s*dr[0]*rdotn*rdotf*oor5 * qw;
-            Df[1] += -s*dr[1]*rdotn*rdotf*oor5 * qw;
-            Df[2] += -s*dr[2]*rdotn*rdotf*oor5 * qw;
+            Sf[0] += -s*dr[0]*rdotn*rdotf*oor5 * qw;
+            Sf[1] += -s*dr[1]*rdotn*rdotf*oor5 * qw;
+            Sf[2] += -s*dr[2]*rdotn*rdotf*oor5 * qw;
+          }
+        }
+      };
+      auto stokes_evalKL = [&](const Vector<Real>& trg, const Vector<Real>& nor, Vector<Real>& Sf) {
+        Sf.ReInit(3);
+        Sf=0;
+        Real scal = 1/(8*const_pi<Real>());
+        for (Long i=0;i<Nt;i++) {
+          Real cos_theta = CosTheta[i];
+          Real sin_theta = sqrt(1-cos_theta*cos_theta);
+          for (Long j=0;j<Np;j++) {
+            Real cos_phi = cos(2*const_pi<Real>()*j/Np);
+            Real sin_phi = sin(2*const_pi<Real>()*j/Np);
+            Real qw = LWeights[i]*2*const_pi<Real>()/Np; // quadrature weights * area-element
+
+            Real f[3]; // source density
+            f[0] = Fgrid[(i*Np+j)*3+0];
+            f[1] = Fgrid[(i*Np+j)*3+1];
+            f[2] = Fgrid[(i*Np+j)*3+2];
+
+            Real x[3]; // source coordinates
+            x[0] = sin_theta*cos_phi;
+            x[1] = sin_theta*sin_phi;
+            x[2] = cos_theta;
+
+            Real dr[3];
+            dr[0] = trg[0] - x[0];
+            dr[1] = trg[1] - x[1];
+            dr[2] = trg[2] - x[2];
+
+            Real invr = 1 / sqrt(dr[0]*dr[0] + dr[1]*dr[1] + dr[2]*dr[2]);
+            Real invr2 = invr*invr;
+            Real invr3 = invr2*invr;
+            Real invr5 = invr2*invr3;
+
+            Real fdotr = dr[0]*f[0]+dr[1]*f[1]+dr[2]*f[2];
+
+            Real du[9];
+            du[0] = (                  fdotr*invr3 - 3*dr[0]*dr[0]*fdotr*invr5) * scal;
+            du[1] = ((dr[0]*f[1]-dr[1]*f[0])*invr3 - 3*dr[0]*dr[1]*fdotr*invr5) * scal;
+            du[2] = ((dr[0]*f[2]-dr[2]*f[0])*invr3 - 3*dr[0]*dr[2]*fdotr*invr5) * scal;
+
+            du[3] = ((dr[1]*f[0]-dr[0]*f[1])*invr3 - 3*dr[1]*dr[0]*fdotr*invr5) * scal;
+            du[4] = (                  fdotr*invr3 - 3*dr[1]*dr[1]*fdotr*invr5) * scal;
+            du[5] = ((dr[1]*f[2]-dr[2]*f[1])*invr3 - 3*dr[1]*dr[2]*fdotr*invr5) * scal;
+
+            du[6] = ((dr[2]*f[0]-dr[0]*f[2])*invr3 - 3*dr[2]*dr[0]*fdotr*invr5) * scal;
+            du[7] = ((dr[2]*f[1]-dr[1]*f[2])*invr3 - 3*dr[2]*dr[1]*fdotr*invr5) * scal;
+            du[8] = (                  fdotr*invr3 - 3*dr[2]*dr[2]*fdotr*invr5) * scal;
+
+            Real p = (2*fdotr*invr3) * scal;
+
+            Real K[9];
+            K[0] = du[0] + du[0] - p; K[1] = du[1] + du[3] - 0; K[2] = du[2] + du[6] - 0;
+            K[3] = du[3] + du[1] - 0; K[4] = du[4] + du[4] - p; K[5] = du[5] + du[7] - 0;
+            K[6] = du[6] + du[2] - 0; K[7] = du[7] + du[5] - 0; K[8] = du[8] + du[8] - p;
+
+            Sf[0] += (K[0]*nor[0] + K[1]*nor[1] + K[2]*nor[2]) * qw;
+            Sf[1] += (K[3]*nor[0] + K[4]*nor[1] + K[5]*nor[2]) * qw;
+            Sf[2] += (K[6]*nor[0] + K[7]*nor[1] + K[8]*nor[2]) * qw;
           }
         }
       };
@@ -234,10 +297,13 @@ template <class Real> class SphericalHarmonics{
       for (Long i = 0; i < 40; i++) { // Evaluate
         Real R0 = (0.01 + i/20.0);
 
-        Vector<Real> x(3);
+        Vector<Real> x(3), n(3);
         x[0] = drand48()-0.5;
         x[1] = drand48()-0.5;
         x[2] = drand48()-0.5;
+        n[0] = drand48()-0.5;
+        n[1] = drand48()-0.5;
+        n[2] = drand48()-0.5;
         Real R = sqrt<Real>(x[0]*x[0]+x[1]*x[1]+x[2]*x[2]);
         x[0] *= R0 / R;
         x[1] *= R0 / R;
@@ -245,19 +311,26 @@ template <class Real> class SphericalHarmonics{
 
         Vector<Real> Sf, Sf_;
         Vector<Real> Df, Df_;
+        Vector<Real> Kf, Kf_;
         StokesEvalSL(Fcoeff, sctl::SHCArrange::ROW_MAJOR, p, x, R0<1, Sf);
         StokesEvalDL(Fcoeff, sctl::SHCArrange::ROW_MAJOR, p, x, R0<1, Df);
+        StokesEvalKL(Fcoeff, sctl::SHCArrange::ROW_MAJOR, p, x, n, R0<1, Kf);
         stokes_evalSL(x, Sf_);
         stokes_evalDL(x, Df_);
+        stokes_evalKL(x, n, Kf_);
 
         auto errSL = (Sf-Sf_)/(Sf+0.01);
         auto errDL = (Df-Df_)/(Df+0.01);
+        auto errKL = (Kf-Kf_)/(Kf+0.01);
         for (auto& x:errSL) x=log(fabs(x))/log(10);
         for (auto& x:errDL) x=log(fabs(x))/log(10);
-        std::cout<<"R = "<<(0.01 + i/20.0)<<";   SL-error = ";
-        std::cout<<errSL;
-        std::cout<<"R = "<<(0.01 + i/20.0)<<";   DL-error = ";
-        std::cout<<errDL;
+        for (auto& x:errKL) x=log(fabs(x))/log(10);
+        //std::cout<<"R = "<<(0.01 + i/20.0)<<";   SL-error = ";
+        //std::cout<<errSL;
+        //std::cout<<"R = "<<(0.01 + i/20.0)<<";   DL-error = ";
+        //std::cout<<errDL;
+        std::cout<<"R = "<<(0.01 + i/20.0)<<";   KL-error = ";
+        std::cout<<errKL;
       }
       Clear();
     }

+ 233 - 2
include/sctl/sph_harm.txx

@@ -653,7 +653,7 @@ template <class Real> void SphericalHarmonics<Real>::StokesEvalSL(const Vector<R
     assert(SHBasis.Dim(0) == N * COORD_DIM);
   }
 
-  Matrix<Real> StokesOp(SHBasis.Dim(0), SHBasis.Dim(1));
+  Matrix<Real> StokesOp(N * COORD_DIM, COORD_DIM * M);
   for (Long i = 0; i < N; i++) { // Set StokesOp
     for (Long m = 0; m <= p0; m++) {
       for (Long n = m; n <= p0; n++) {
@@ -814,7 +814,7 @@ template <class Real> void SphericalHarmonics<Real>::StokesEvalDL(const Vector<R
     assert(SHBasis.Dim(0) == N * COORD_DIM);
   }
 
-  Matrix<Real> StokesOp(SHBasis.Dim(0), SHBasis.Dim(1));
+  Matrix<Real> StokesOp(N * COORD_DIM, COORD_DIM * M);
   for (Long i = 0; i < N; i++) { // Set StokesOp
     for (Long m = 0; m <= p0; m++) {
       for (Long n = m; n <= p0; n++) {
@@ -937,6 +937,237 @@ template <class Real> void SphericalHarmonics<Real>::StokesEvalDL(const Vector<R
   }
 }
 
+template <class Real> void SphericalHarmonics<Real>::StokesEvalKL(const Vector<Real>& S, SHCArrange arrange, Long p0, const Vector<Real>& coord, const Vector<Real>& norm, bool interior, Vector<Real>& X) {
+  Long M = (p0+1) * (p0+1);
+
+  Long dof;
+  Matrix<Real> B1;
+  { // Set B1, dof
+    Vector<Real> B0;
+    SHCArrange1(S, arrange, p0, B0);
+    dof = B0.Dim() / M / COORD_DIM;
+    assert(B0.Dim() == dof * COORD_DIM * M);
+
+    B1.ReInit(dof, COORD_DIM * M);
+    Vector<Real> B1_(B1.Dim(0) * B1.Dim(1), B1.begin(), false);
+    SHCArrange0(B0, p0, B1_, SHCArrange::COL_MAJOR_NONZERO);
+  }
+  assert(B1.Dim(1) == COORD_DIM * M);
+  assert(B1.Dim(0) == dof);
+
+  Long N = coord.Dim() / COORD_DIM;
+  assert(coord.Dim() == N * COORD_DIM);
+
+  Matrix<Real> SHBasis;
+  Vector<Real> R, cos_theta_phi;
+  { // Set R, SHBasis
+    R.ReInit(N);
+    cos_theta_phi.ReInit(2 * N);
+    for (Long i = 0; i < N; i++) { // Set R, cos_theta_phi
+      ConstIterator<Real> x = coord.begin() + i * COORD_DIM;
+      R[i] = sqrt<Real>(x[0]*x[0] + x[1]*x[1] + x[2]*x[2]);
+      cos_theta_phi[i * 2 + 0] = x[2] / R[i];
+      cos_theta_phi[i * 2 + 1] = atan2(x[1], x[0]); // TODO: works only for float and double
+    }
+    SHBasisEval(p0, cos_theta_phi, SHBasis);
+    assert(SHBasis.Dim(1) == M);
+    assert(SHBasis.Dim(0) == N);
+  }
+
+  Matrix<Real> StokesOp(N * COORD_DIM, COORD_DIM * M);
+  for (Long i = 0; i < N; i++) { // Set StokesOp
+    StaticArray<Real, COORD_DIM> norm0;
+    Real cos_theta, sin_theta, cos_phi, sin_phi;
+    { // Set cos_theta, sin_theta, cos_phi, sin_phi
+      cos_theta = cos_theta_phi[i * 2 + 0];
+      sin_theta = sqrt<Real>(1 - cos_theta * cos_theta);
+      cos_phi = cos(cos_theta_phi[i * 2 + 1]);
+      sin_phi = sin(cos_theta_phi[i * 2 + 1]);
+    }
+    { // Set norm0 <-- Q^t * norm
+      StaticArray<Real,9> Q;
+      { // Set Q
+        Q[0] = sin_theta*cos_phi; Q[1] = sin_theta*sin_phi; Q[2] = cos_theta;
+        Q[3] = cos_theta*cos_phi; Q[4] = cos_theta*sin_phi; Q[5] =-sin_theta;
+        Q[6] =          -sin_phi; Q[7] =           cos_phi; Q[8] =         0;
+      }
+      StaticArray<Real,COORD_DIM> in;
+      in[0] = norm[i * COORD_DIM + 0];
+      in[1] = norm[i * COORD_DIM + 1];
+      in[2] = norm[i * COORD_DIM + 2];
+      norm0[0] = Q[0] * in[0] + Q[1] * in[1] + Q[2] * in[2];
+      norm0[1] = Q[3] * in[0] + Q[4] * in[1] + Q[5] * in[2];
+      norm0[2] = Q[6] * in[0] + Q[7] * in[1] + Q[8] * in[2];
+    }
+
+    Complex<Real> imag(0,1);
+    Complex<Real> exp_iphi(cos_phi, sin_phi);
+    Complex<Real> exp_iphi_conj(cos_phi, -sin_phi);
+    Real cot_theta = cos_theta / sin_theta;
+    Real csc_theta = 1 / sin_theta;
+    Real cos_2theta = 2 * cos_theta * cos_theta - 1;
+
+    for (Long m = 0; m <= p0; m++) {
+      for (Long n = m; n <= p0; n++) {
+        auto read_coeff = [&](Long n, Long m) {
+          Complex<Real> c;
+          if (0 <= m && m <= n && n <= p0) {
+            Long idx = (2 * p0 - m + 2) * m - (m ? p0+1 : 0) + n;
+            c.real = SHBasis[i][idx];
+            if (m) {
+              idx += (p0+1-m);
+              c.imag = SHBasis[i][idx];
+            }
+          }
+          return c;
+        };
+        auto write_coeff = [&](Complex<Real> c, Long n, Long m, Long k0, Long k1) {
+          if (0 <= m && m <= n && n <= p0 && 0 <= k0 && k0 < COORD_DIM && 0 <= k1 && k1 < COORD_DIM) {
+            Long idx = (2 * p0 - m + 2) * m - (m ? p0+1 : 0) + n;
+            StokesOp[i * COORD_DIM + k1][k0 * M + idx] = c.real;
+            if (m) {
+              idx += (p0+1-m);
+              StokesOp[i * COORD_DIM + k1][k0 * M + idx] = c.imag;
+            }
+          }
+        };
+
+        auto Ynm0 = read_coeff(n, m + 0);
+        auto Ynm1 = read_coeff(n, m + 1);
+        auto Ynm2 = read_coeff(n, m + 2);
+
+        Complex<Real> KV[COORD_DIM][COORD_DIM];
+        Complex<Real> KW[COORD_DIM][COORD_DIM];
+        Complex<Real> KX[COORD_DIM][COORD_DIM];
+        if (interior) {
+          KV[0][0] = 0;
+          KV[0][1] = 0;
+          KV[0][2] = 0;
+          KV[1][0] = 0;
+          KV[1][1] = 0;
+          KV[1][2] = 0;
+          KV[2][0] = 0;
+          KV[2][1] = 0;
+          KV[2][2] = 0;
+
+          KW[0][0] = 0;
+          KW[0][1] = 0;
+          KW[0][2] = 0;
+          KW[1][0] = 0;
+          KW[1][1] = 0;
+          KW[1][2] = 0;
+          KW[2][0] = 0;
+          KW[2][1] = 0;
+          KW[2][2] = 0;
+
+          KX[0][0] = 0;
+          KX[0][1] = 0;
+          KX[0][2] = 0;
+          KX[1][0] = 0;
+          KX[1][1] = 0;
+          KX[1][2] = 0;
+          KX[2][0] = 0;
+          KX[2][1] = 0;
+          KX[2][2] = 0;
+        } else {
+          Real r = R[i];
+
+          KV[0][0] =  (2*n*(n*n+3*n+2)*pow<Real>(r,-n-3)*Ynm0) / (4*n*n+8*n+3);
+          KW[0][0] = -(n*pow<Real>(r,-n-3)*(2*n*n*n*(r*r-1) + n*n*(7*r*r-5) + n*(r*r-1) - r*r + 2)*Ynm0) / (4*n*n-1);
+          KX[0][0] =  0;
+
+          KV[0][1] = -(2*n*(n+2)*exp_iphi_conj*pow<Real>(r,-n-3)*(sqrt<Real>(-m*m-m+n*n+n)*Ynm1 + m*exp_iphi*cot_theta*Ynm0)) / (4*n*n+8*n+3);
+          KW[0][1] =  (exp_iphi_conj*pow<Real>(r,-n-3)*(2*n*n*n*(r*r-1) + n*n*(r*r-3) - 2*n*(r*r-1) - r*r))*(sqrt<Real>(-m*m-m+n*n+n)*Ynm1 + m*exp_iphi*cot_theta*Ynm0) / (4*n*n-1);
+          KX[0][1] =  (imag*m*(n+2)*pow<Real>(r,-n-2)*csc_theta*Ynm0) / (2*n+1);
+
+          KV[0][2] = -(2*imag*m*n*(n+2)*pow<Real>(r,-n-3)*csc_theta*Ynm0) / (4*n*n+8*n+3);
+          KW[0][2] =  (imag*m*pow<Real>(r,-n-3)*(2*n*n*n*(r*r-1) + n*n*(r*r-3) - 2*n*(r*r-1) - r*r)*csc_theta*Ynm0) / (4*n*n-1);
+          KX[0][2] =  (pow<Real>(r,-n-3)*(-m*(n+2)*r*cot_theta*Ynm0 - (n+2)*exp_iphi_conj*r*sqrt<Real>(-m*(m+1) + n*n + n)*Ynm1)) / (2*n+1);
+
+          KV[1][0] = -(2*n*(n+2)*exp_iphi_conj*pow<Real>(r,-n-3)*(sqrt<Real>(-m*m-m+n*n+n)*Ynm1 + m*exp_iphi*cot_theta*Ynm0)) / (4*n*n+8*n+3);
+          KW[1][0] =  (exp_iphi_conj*pow<Real>(r,-n-3)*(2*n*n*n*(r*r-1) + n*n*(r*r-3) - 2*n*(r*r-1) - r*r))*(sqrt<Real>(-m*m-m+n*n+n)*Ynm1 + m*exp_iphi*cot_theta*Ynm0) / (4*n*n-1);
+          KX[1][0] =  (imag*m*(n+2)*pow<Real>(r,-n-2)*csc_theta*Ynm0) / (2*n+1);
+
+          KV[1][1] =  (2*(2*m+1)*n*exp_iphi*sqrt<Real>(-m*m-m+n*n+n)*cot_theta*Ynm1 - 2*n*exp_iphi*exp_iphi*(-m*m*cot_theta*cot_theta+m*csc_theta*csc_theta+n+1)*Ynm0 + 2*n*sqrt<Real>(m*m*m*m + 4*m*m*m + m*m*(-2*n*n-2*n+5) + m*(-4*n*n-4*n+2) + n*(n*n*n+2*n*n-n-2))*Ynm2) * (exp_iphi_conj*exp_iphi_conj*pow<Real>(r,-n-3)) / (4*n*n+8*n+3);
+          KW[1][1] =  (Ynm0*((m-1)*m*exp_iphi*exp_iphi*(2*n*n-(n-2)*(2*n+1)*r*r-n)*csc_theta*csc_theta - exp_iphi*exp_iphi*((n-2)*(2*n+1)*r*r*(n-m*m)+n*(2*n-1)*(m*m+n+1))) + sqrt<Real>((m-n)*(m-n+1)*(m+n+1)*(m+n+2))*(2*n*n-(n-2)*(2*n+1)*r*r-n)*Ynm2 + (2*m+1)*exp_iphi*sqrt<Real>(-m*(m+1)+n*n+n)*(2*n*n-(n-2)*(2*n+1)*r*r-n)*cot_theta*Ynm1) * (exp_iphi_conj*exp_iphi_conj*pow<Real>(r,-n-3)) / (4*n*n-1);
+          KX[1][1] = -(sqrt<Real>(-m*m-m+n*n+n)*Ynm1 + (m-1)*exp_iphi*cot_theta*Ynm0)*(2*imag*m*exp_iphi_conj*pow<Real>(r,-n-2)*csc_theta) / (2*n+1);
+
+          KV[1][2] =  (2*imag*m*n*exp_iphi_conj*pow<Real>(r,-n-3)*csc_theta*(sqrt<Real>(-m*m-m+n*n+n)*Ynm1 + (m-1)*exp_iphi*cot_theta*Ynm0)) / (4*n*n+8*n+3);
+          KW[1][2] = -(imag*m*exp_iphi_conj*pow<Real>(r,-n-3)*(-2*n*n + (n-2)*(2*n+1)*r*r + n)*csc_theta)*(sqrt<Real>(-m*(m+1) + n*n + n)*Ynm1 + (m-1)*exp_iphi*cot_theta*Ynm0) / (4*n*n-1);
+          KX[1][2] =  (4*m*exp_iphi*sqrt<Real>(-m*(m+1)+n*n+n)*cot_theta*Ynm1 + 2*sqrt<Real>((m-n)*(m-n+1)*(m+n+1)*(m+n+2))*Ynm2 + (m-1)*m*exp_iphi*exp_iphi*(cos_2theta+3)*csc_theta*csc_theta*Ynm0)*(exp_iphi_conj*exp_iphi_conj*pow<Real>(r,-n-2)) / (2*(2*n+1));
+
+          KV[2][0] = -(2*imag*m*n*(n+2)*pow<Real>(r,-n-3)*csc_theta*Ynm0) / (4*n*n+8*n+3);
+          KW[2][0] =  (imag*m*pow<Real>(r,-n-3)*(2*n*n*n*(r*r-1) + n*n*(r*r-3) - 2*n*(r*r-1) - r*r)*csc_theta*Ynm0) / (4*n*n-1);
+          KX[2][0] =  (pow<Real>(r,-n-3)*(-m*(n+2)*r*cot_theta*Ynm0 - (n+2)*exp_iphi_conj*r*sqrt<Real>(-m*(m+1)+n*n+n)*Ynm1)) / (2*n+1);
+
+          KV[2][1] =  (2*imag*m*n*exp_iphi_conj*pow<Real>(r,-n-3)*csc_theta*(sqrt<Real>(-m*m-m+n*n+n)*Ynm1 + (m-1)*exp_iphi*cot_theta*Ynm0)) / (4*n*n+8*n+3);
+          KW[2][1] = -(imag*m*exp_iphi_conj*pow<Real>(r,-n-3)*(-2*n*n + (n-2)*(2*n+1)*r*r + n)*csc_theta)*(sqrt<Real>(-m*(m+1)+n*n+n)*Ynm1 + (m-1)*exp_iphi*cot_theta*Ynm0) / (4*n*n-1);
+          KX[2][1] =  (4*m*exp_iphi*sqrt<Real>(-m*(m+1)+n*n+n)*cot_theta*Ynm1 + 2*sqrt<Real>((m-n)*(m-n+1)*(m+n+1)*(m+n+2))*Ynm2 + (m-1)*m*exp_iphi*exp_iphi*(cos_2theta+3)*csc_theta*csc_theta*Ynm0)*(exp_iphi_conj*exp_iphi_conj*pow<Real>(r,-n-2)) / (2*(2*n+1));
+
+          KV[2][2] =  (2*n*sqrt<Real>(-m*m-m+n*n+n)*cot_theta*Ynm1 - 2*n*exp_iphi*(m*m*csc_theta*csc_theta-m*cot_theta*cot_theta+n+1)*Ynm0)*(exp_iphi_conj*pow<Real>(r,-n-3)) / (4*n*n+8*n+3);
+          KW[2][2] =  (-sqrt<Real>(-m*(m+1)+n*n+n)*(-2*n*n + (n-2)*(2*n+1)*r*r + n)*cot_theta*Ynm1 + Ynm0*(m*exp_iphi*(-2*n*n + (n-2)*(2*n+1)*r*r + n)*((m-1)*csc_theta*csc_theta+1) - n*exp_iphi*(2*n*n + (n-2)*(2*n+1)*r*r + n - 1)))*(exp_iphi_conj*pow<Real>(r,-n-3)) / (4*n*n-1);
+          KX[2][2] =  (sqrt<Real>(-m*(m+1)+n*n+n)*Ynm1 + (m-1)*exp_iphi*cot_theta*Ynm0)*(2*imag*m*exp_iphi_conj*pow<Real>(r,-n-2)*csc_theta) / (2*n+1);
+        }
+
+        Complex<Real> SVr, SVt, SVp;
+        SVr = KV[0][0] * norm0[0] + KV[0][1] * norm0[1] + KV[0][2] * norm0[2];
+        SVt = KV[1][0] * norm0[0] + KV[1][1] * norm0[1] + KV[1][2] * norm0[2];
+        SVp = KV[2][0] * norm0[0] + KV[2][1] * norm0[1] + KV[2][2] * norm0[2];
+
+        Complex<Real> SWr, SWt, SWp;
+        SWr = KW[0][0] * norm0[0] + KW[0][1] * norm0[1] + KW[0][2] * norm0[2];
+        SWt = KW[1][0] * norm0[0] + KW[1][1] * norm0[1] + KW[1][2] * norm0[2];
+        SWp = KW[2][0] * norm0[0] + KW[2][1] * norm0[1] + KW[2][2] * norm0[2];
+
+        Complex<Real> SXr, SXt, SXp;
+        SXr = KX[0][0] * norm0[0] + KX[0][1] * norm0[1] + KX[0][2] * norm0[2];
+        SXt = KX[1][0] * norm0[0] + KX[1][1] * norm0[1] + KX[1][2] * norm0[2];
+        SXp = KX[2][0] * norm0[0] + KX[2][1] * norm0[1] + KX[2][2] * norm0[2];
+
+        write_coeff(SVr, n, m, 0, 0);
+        write_coeff(SVt, n, m, 0, 1);
+        write_coeff(SVp, n, m, 0, 2);
+
+        write_coeff(SWr, n, m, 1, 0);
+        write_coeff(SWt, n, m, 1, 1);
+        write_coeff(SWp, n, m, 1, 2);
+
+        write_coeff(SXr, n, m, 2, 0);
+        write_coeff(SXt, n, m, 2, 1);
+        write_coeff(SXp, n, m, 2, 2);
+      }
+    }
+  }
+
+  { // Set X <-- Q * StokesOp * B1
+    if (X.Dim() != N * dof * COORD_DIM) X.ReInit(N * dof * COORD_DIM);
+    for (Long k0 = 0; k0 < N; k0++) {
+      StaticArray<Real,9> Q;
+      { // Set Q
+        Real cos_theta = cos_theta_phi[k0 * 2 + 0];
+        Real sin_theta = sqrt<Real>(1 - cos_theta * cos_theta);
+        Real cos_phi = cos(cos_theta_phi[k0 * 2 + 1]);
+        Real sin_phi = sin(cos_theta_phi[k0 * 2 + 1]);
+        Q[0] = sin_theta*cos_phi; Q[1] = sin_theta*sin_phi; Q[2] = cos_theta;
+        Q[3] = cos_theta*cos_phi; Q[4] = cos_theta*sin_phi; Q[5] =-sin_theta;
+        Q[6] =          -sin_phi; Q[7] =           cos_phi; Q[8] =         0;
+      }
+      for (Long k1 = 0; k1 < dof; k1++) { // Set X <-- Q * StokesOp * B1
+        StaticArray<Real,COORD_DIM> in;
+        for (Long j = 0; j < COORD_DIM; j++) {
+          in[j] = 0;
+          for (Long i = 0; i < COORD_DIM * M; i++) {
+            in[j] += B1[k1][i] * StokesOp[k0 * COORD_DIM + j][i];
+          }
+        }
+        X[(k0 * dof + k1) * COORD_DIM + 0] = Q[0] * in[0] + Q[3] * in[1] + Q[6] * in[2];
+        X[(k0 * dof + k1) * COORD_DIM + 1] = Q[1] * in[0] + Q[4] * in[1] + Q[7] * in[2];
+        X[(k0 * dof + k1) * COORD_DIM + 2] = Q[2] * in[0] + Q[5] * in[1] + Q[8] * in[2];
+      }
+    }
+  }
+}
+
 
 
 

+ 9 - 0
src/test.cpp

@@ -58,6 +58,12 @@ void TestMatrix() {
 }
 
 int main(int argc, char** argv) {
+  sctl::SphericalHarmonics<double>::test_stokes();
+  return;
+
+#ifdef SCTL_HAVE_MPI
+  MPI_Init(&argc, &argv);
+#endif
 
   // Dry run (profiling disabled)
   ProfileMemgr();
@@ -78,5 +84,8 @@ int main(int argc, char** argv) {
     // sctl::aligned_delete(A); // Show memory leak warning when commented
   }
 
+#ifdef SCTL_HAVE_MPI
+    MPI_Finalize();
+#endif
   return 0;
 }