Dhairya Malhotra 7 years ago
parent
commit
a87d305deb
4 changed files with 62 additions and 32 deletions
  1. 6 6
      include/sctl/comm.txx
  2. 52 6
      include/sctl/mem_mgr.hpp
  3. 0 20
      include/sctl/mem_mgr.txx
  4. 4 0
      include/sctl/sph_harm.txx

+ 6 - 6
include/sctl/comm.txx

@@ -632,7 +632,7 @@ template <class Type> void Comm::ScatterForward(Vector<Type>& data_, const Vecto
     StaticArray<Long, 2> loc_size;
     loc_size[0] = data_.Dim();
     loc_size[1] = recv_size;
-    Allreduce(loc_size, glb_size, 2, CommOp::SUM);
+    Allreduce<Long>(loc_size, glb_size, 2, CommOp::SUM);
     if (glb_size[0] == 0 || glb_size[1] == 0) return;  // Nothing to be done.
     data_dim = glb_size[0] / glb_size[1];
     SCTL_ASSERT(glb_size[0] == data_dim * glb_size[1]);
@@ -758,7 +758,7 @@ template <class Type> void Comm::ScatterReverse(Vector<Type>& data_, const Vecto
     loc_size[0] = data_.Dim();
     loc_size[1] = scatter_index_.Dim();
     loc_size[2] = recv_size;
-    Allreduce(loc_size, glb_size, 3, CommOp::SUM);
+    Allreduce<Long>(loc_size, glb_size, 3, CommOp::SUM);
     if (glb_size[0] == 0 || glb_size[1] == 0) return;  // Nothing to be done.
 
     SCTL_ASSERT(glb_size[0] % glb_size[1] == 0);
@@ -792,16 +792,16 @@ template <class Type> void Comm::ScatterReverse(Vector<Type>& data_, const Vecto
     StaticArray<Long, 2> loc_size;
     loc_size[0] = data_.Dim() / data_dim;
     loc_size[1] = scatter_index_.Dim();
-    Scan(loc_size, glb_rank, 2, CommOp::SUM);
-    Allreduce(loc_size, glb_size, 2, CommOp::SUM);
+    Scan<Long>(loc_size, glb_rank, 2, CommOp::SUM);
+    Allreduce<Long>(loc_size, glb_size, 2, CommOp::SUM);
     SCTL_ASSERT(glb_size[0] == glb_size[1]);
     glb_rank[0] -= loc_size[0];
     glb_rank[1] -= loc_size[1];
 
     Vector<Long> glb_scan0(npes + 1);
     Vector<Long> glb_scan1(npes + 1);
-    Allgather(glb_rank + 0, 1, glb_scan0.begin(), 1);
-    Allgather(glb_rank + 1, 1, glb_scan1.begin(), 1);
+    Allgather<Long>(glb_rank + 0, 1, glb_scan0.begin(), 1);
+    Allgather<Long>(glb_rank + 1, 1, glb_scan1.begin(), 1);
     glb_scan0[npes] = glb_size[0];
     glb_scan1[npes] = glb_size[1];
 

+ 52 - 6
include/sctl/mem_mgr.hpp

@@ -235,16 +235,15 @@ template <class ValueType> class Iterator : public ConstIterator<ValueType> {
   difference_type operator-(const ConstIterator<ValueType>& I) const { return static_cast<const ConstIterator<ValueType>&>(*this) - I; }
 };
 
-template <class ValueType, Long DIM> class StaticArray : public Iterator<ValueType> { // Warning: objects are not byte-copyable // TODO: Can be made by copyable by not inheriting Iterator and can also add memory header and padding to detect additional memory errors
+template <class ValueType, Long DIM> class StaticArray { // Warning: objects are not byte-copyable // TODO: Can be made by copyable by not inheriting Iterator and can also add memory header and padding to detect additional memory errors
+  typedef Long difference_type;
 
  public:
-  StaticArray();
-
-  StaticArray(const StaticArray&);
+  StaticArray() = default;
 
-  StaticArray& operator=(const StaticArray&);
+  StaticArray(const StaticArray&) = default;
 
-  ~StaticArray();
+  StaticArray& operator=(const StaticArray&) = default;
 
   StaticArray(std::initializer_list<ValueType> arr_) : StaticArray() {
     // static_assert(arr_.size() <= DIM, "too many initializer values"); // allowed in C++14
@@ -252,6 +251,53 @@ template <class ValueType, Long DIM> class StaticArray : public Iterator<ValueTy
     for (Long i = 0; i < (Long)arr_.size(); i++) (*this)[i] = arr_.begin()[i];
   }
 
+  ~StaticArray() = default;
+
+  // value_type* like operators
+  const ValueType& operator*() const { return *arr_; }
+
+  ValueType& operator*() { return *arr_; }
+
+  const ValueType* operator->() const { return arr_; }
+
+  ValueType* operator->() { return arr_; }
+
+  const ValueType& operator[](difference_type off) const { return arr_[off]; }
+
+  ValueType& operator[](difference_type off) { return arr_[off]; }
+
+  operator ConstIterator<ValueType>() const { return Iterator<ValueType>(arr_, DIM); }
+
+  operator Iterator<ValueType>() { return Iterator<ValueType>(arr_, DIM); }
+
+  // Arithmetic
+  ConstIterator<ValueType> operator+(difference_type i) const { return (ConstIterator<ValueType>)*this + i; }
+
+  Iterator<ValueType> operator+(difference_type i) { return (Iterator<ValueType>)*this + i; }
+
+  friend ConstIterator<ValueType> operator+(difference_type i, const StaticArray& right) { return i + (ConstIterator<ValueType>)right; }
+
+  friend Iterator<ValueType> operator+(difference_type i, StaticArray& right) { return i + (Iterator<ValueType>)right; }
+
+  ConstIterator<ValueType> operator-(difference_type i) const { return (ConstIterator<ValueType>)*this - i; }
+
+  Iterator<ValueType> operator-(difference_type i) { return (Iterator<ValueType>)*this - i; }
+
+  difference_type operator-(const ConstIterator<ValueType>& I) const { return (ConstIterator<ValueType>)*this - (ConstIterator<ValueType>)I; }
+
+  // Comparison operators
+  bool operator==(const ConstIterator<ValueType>& I) const { return (ConstIterator<ValueType>)*this == I; }
+
+  bool operator!=(const ConstIterator<ValueType>& I) const { return (ConstIterator<ValueType>)*this != I; }
+
+  bool operator< (const ConstIterator<ValueType>& I) const { return (ConstIterator<ValueType>)*this <  I; }
+
+  bool operator<=(const ConstIterator<ValueType>& I) const { return (ConstIterator<ValueType>)*this <= I; }
+
+  bool operator> (const ConstIterator<ValueType>& I) const { return (ConstIterator<ValueType>)*this >  I; }
+
+  bool operator>=(const ConstIterator<ValueType>& I) const { return (ConstIterator<ValueType>)*this >= I; }
+
  private:
 
   ValueType arr_[DIM];

+ 0 - 20
include/sctl/mem_mgr.txx

@@ -68,26 +68,6 @@ template <class ValueType> inline typename Iterator<ValueType>::reference Iterat
   return *(ValueType*)(this->base + this->offset + j * (Long)sizeof(ValueType));
 }
 
-template <class ValueType, Long DIM> inline StaticArray<ValueType, DIM>::StaticArray() {
-  //Iterator<ValueType>::operator=(aligned_new<ValueType>(DIM));
-  Iterator<ValueType>::operator=(Ptr2Itr<ValueType>(arr_, DIM));
-}
-
-template <class ValueType, Long DIM> inline StaticArray<ValueType, DIM>::~StaticArray() {
-  // aligned_delete<ValueType>(*this);
-}
-
-template <class ValueType, Long DIM> inline StaticArray<ValueType, DIM>::StaticArray(const StaticArray& I) {
-  //Iterator<ValueType>::operator=(aligned_new<ValueType>(DIM));
-  Iterator<ValueType>::operator=(Ptr2Itr<ValueType>(arr_, DIM));
-  for (Long i = 0; i < DIM; i++) (*this)[i] = I[i];
-}
-
-template <class ValueType, Long DIM> inline StaticArray<ValueType, DIM>& StaticArray<ValueType, DIM>::operator=(const StaticArray& I) {
-  for (Long i = 0; i < DIM; i++) (*this)[i] = I[i];
-  return *this;
-}
-
 #endif
 
 inline MemoryManager::MemoryManager(Long N) {

+ 4 - 0
include/sctl/sph_harm.txx

@@ -177,9 +177,11 @@ template <class Real> void SphericalHarmonics<Real>::SHC2Grid(const Vector<Real>
 
   Long M, N;
   { // Set M, N
+    M = 0;
     if (arrange == SHCArrange::ALL) M = 2*(p0+1)*(p0+1);
     if (arrange == SHCArrange::ROW_MAJOR) M = (p0+1)*(p0+2);
     if (arrange == SHCArrange::COL_MAJOR_NONZERO) M = (p0+1)*(p0+1);
+    if (M == 0) return;
     N = S.Dim() / M;
     assert(S.Dim() == N * M);
   }
@@ -414,9 +416,11 @@ template <class Real> void SphericalHarmonics<Real>::SHC2Pole(const Vector<Real>
 
   Long M, N;
   { // Set M, N
+    M = 0;
     if (arrange == SHCArrange::ALL) M = 2*(p0+1)*(p0+1);
     if (arrange == SHCArrange::ROW_MAJOR) M = (p0+1)*(p0+2);
     if (arrange == SHCArrange::COL_MAJOR_NONZERO) M = (p0+1)*(p0+1);
+    if (M == 0) return;
     N = S.Dim() / M;
     assert(S.Dim() == N * M);
   }